使用pytorch制作自定义数据集并用DataLoader加载

项目目录
在这里插入图片描述
custom_dataset.py

from torch.utils.data import Dataset
import os
from torchvision.io import image


# 所有数据集都要继承Dataset类
class MyData(Dataset):
    def __init__(self, root_dir, label_dir, transform=None):
        # self 指定了一个类当中的全局变量,该变量可以让后面的函数使用
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.transform = transform
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 这里的图片路径不是一张图片的路径,是一个路径数组
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        #读取图片,读取后魏tensor类型,该方法只支持读取png格式图片和jpeg格式图片
        img = image.read_image(img_item_path)
        label = self.label_dir
        if label == "ants":
            label = 0
        else:
            label = 1
        if self.transform is not None:
            img = self.transform(img)  # 对图片进行某些变换
        return img, label

    def __len__(self):
        return len(self.img_path)

train.py

from torch.utils.data import DataLoader
from torchvision import transforms
from custom_dataset import MyData

transform = transforms.Resize((224, 224))
# 训练集的路径
train_root_dir = "dataset/train"
# 验证集路径
val_root_dir = "dataset/val"
# 分类的label
ants_label_dir = "ants"
bees_label_dir = "bees"

# 创建训练集
train_dataset = train_dataset = MyData(train_root_dir, ants_label_dir, transform) + MyData(train_root_dir, bees_label_dir, transform)
# 创建验证集
val_dataset = MyData(val_root_dir, ants_label_dir, transform) + MyData(val_root_dir, bees_label_dir, transform)

# 加载训练集
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
# 加载验证集
val_loader = DataLoader(val_dataset, batch_size=5, shuffle=True)


if __name__ == "__main__":
    #输出训练集
    for batchidx,(imgs,labels) in enumerate(train_loader):
        print("训练集的第{}个batch,他的shape是:{},他的label是:{}".format((batchidx+1),imgs.shape,labels))
    #输出验证集
    for batchidx,(imgs,labels) in enumerate(val_loader):
        print("验证集的第{}个batch,他的shape是:{},他的label是:{}".format((batchidx+1),imgs.shape,labels))

**写在后面:**上诉自定义数据集中方法有很多,最重要的就是要把图片转成tensor,还有就是要做resize不然会报错,其次是关于label的问题,在此处因为只有两个分类所以我就简单做了一下判断给label赋了值,如果label较多的话可以用csv文件的形式去读取图片与其对应的label。
项目github:https://github.com/Sjyzheishuai/CustomDataset

猜你喜欢

转载自blog.csdn.net/weixin_44747173/article/details/127445974
今日推荐