项目目录
custom_dataset.py
from torch.utils.data import Dataset
import os
from torchvision.io import image
# 所有数据集都要继承Dataset类
class MyData(Dataset):
def __init__(self, root_dir, label_dir, transform=None):
# self 指定了一个类当中的全局变量,该变量可以让后面的函数使用
self.root_dir = root_dir
self.label_dir = label_dir
self.transform = transform
self.path = os.path.join(self.root_dir, self.label_dir)
# 这里的图片路径不是一张图片的路径,是一个路径数组
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
#读取图片,读取后魏tensor类型,该方法只支持读取png格式图片和jpeg格式图片
img = image.read_image(img_item_path)
label = self.label_dir
if label == "ants":
label = 0
else:
label = 1
if self.transform is not None:
img = self.transform(img) # 对图片进行某些变换
return img, label
def __len__(self):
return len(self.img_path)
train.py
from torch.utils.data import DataLoader
from torchvision import transforms
from custom_dataset import MyData
transform = transforms.Resize((224, 224))
# 训练集的路径
train_root_dir = "dataset/train"
# 验证集路径
val_root_dir = "dataset/val"
# 分类的label
ants_label_dir = "ants"
bees_label_dir = "bees"
# 创建训练集
train_dataset = train_dataset = MyData(train_root_dir, ants_label_dir, transform) + MyData(train_root_dir, bees_label_dir, transform)
# 创建验证集
val_dataset = MyData(val_root_dir, ants_label_dir, transform) + MyData(val_root_dir, bees_label_dir, transform)
# 加载训练集
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
# 加载验证集
val_loader = DataLoader(val_dataset, batch_size=5, shuffle=True)
if __name__ == "__main__":
#输出训练集
for batchidx,(imgs,labels) in enumerate(train_loader):
print("训练集的第{}个batch,他的shape是:{},他的label是:{}".format((batchidx+1),imgs.shape,labels))
#输出验证集
for batchidx,(imgs,labels) in enumerate(val_loader):
print("验证集的第{}个batch,他的shape是:{},他的label是:{}".format((batchidx+1),imgs.shape,labels))
**写在后面:**上诉自定义数据集中方法有很多,最重要的就是要把图片转成tensor,还有就是要做resize不然会报错,其次是关于label的问题,在此处因为只有两个分类所以我就简单做了一下判断给label赋了值,如果label较多的话可以用csv文件的形式去读取图片与其对应的label。
项目github:https://github.com/Sjyzheishuai/CustomDataset