pytorh学习笔记——cifar10(二)加载数据

 本阶段的任务是将训练数据和测试数据进行预处理和创建加载器,以供后面的网络使用。

新建load_cifar.py:

import glob
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image

label_dict = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,  # 类别标签对应的数字
              'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}


# print(label_dict)


def default_loader(path):  # 定义图片加载函数
    return Image.open(path).convert('RGB')  # 转换成RGB模式


train_transform = transforms.Compose([  # 训练集数据预处理
    # transforms.Resize((32, 32)),  # 将图像大小调整为32x32
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    # transforms.RandomRotation(90),  # 随机旋转90度
    # transforms.RandomGrayscale(p=0.1),  # 随机将图像转换为灰度图,p=0.1表示有10%的概率执行该操作
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # 调整图像的亮度、对比度、饱和度和色调
    transforms.ToTensor(),  # 将图像转换为Tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465),  # 标准化
                         (0.2023, 0.1994, 0.2010))
])
# (0.4914, 0.4822, 0.4465)是均值,(0.2023, 0.1994, 0.2010)是标准差,这两组数字是根据训练集数据计算出来的
# 计算方法见:https://blog.csdn.net/xulibo5828/article/details/143143550

test_transform = transforms.Compose([  # 测试集数据预处理
    # transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))])


class MyDataset(Dataset):  # 自定义数据集

    def __init__(self, img_list,  # 图片的地址列表
                 transform=None,  # 数据预处理
                 loader=default_loader):  # 图片加载函数
        super(MyDataset, self).__init__()
        imgs = []  # 图片列表

        for img_path in img_list:
            # 图像文件的地址,典型格式为:
            # E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\train\\ship\\abandoned_ship_s_000004.png
            im_label_name = img_path.split('\\')[-2]  # 图片所属类别的名称,这里使用的是绝对路径,文件目录分隔符为反斜杠,使用相对路径则为正斜杠
            imgs.append([img_path, label_dict[im_label_name]])  # 将图片路径和对应的类别标签添加到列表中
        self.imgs = imgs  # 图片列表
        self.transform = transform  # 数据预处理
        self.loader = loader  # 图片加载函数

    def __getitem__(self, idx):  # 获取图片数据  # 请注意,这个是PyTorch的Dataset类中必须实现的方法
        img_path, label = self.imgs[idx]  # 获取图片路径和对应的类别标签
        im_data = self.loader(img_path)  # 加载图片,并得到图像的数据

        if self.transform:  # 如果有定义数据预处理
            im_data = self.transform(im_data)  # 对图像进行预处理,转换为Tensor等

        return im_data, label  # 返回图片数据和对应的类别标签

    def __len__(self):  # 返回数据集的长度  # 请注意,这个也是PyTorch的Dataset类中必须实现的方法
        return len(self.imgs)  # 返回图片列表的长度


# 获取训练集的文件名
train_list = glob.glob('E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\train\\*\\*.png')  # 获取训练集的文件名
# 获取测试集的文件名
test_list = glob.glob('E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\test\\*\\*.png')  # 获取测试集的文件名

# print(len(train_list))  # 50000
print(test_list[:5])  # 10000

# 定义训练数据集
trans_dataSet = MyDataset(img_list=train_list, transform=train_transform)  # 自定义的数据集,地址为训练集的文件名,数据预处理为transform
# print(trans_dataSet.__len__())  # 50000

# 定义测试数据集
test_dataSet = MyDataset(img_list=test_list, transform=test_transform)  # 自定义的数据集,地址为测试集的文件名,数据预处理为test_transform
# print(test_dataSet.__len__())  # 10000

# 定义训练集的加载器
train_loader = DataLoader(dataset=trans_dataSet, batch_size=128, shuffle=True,
                          num_workers=8)  # 以随机顺序加载训练数据集,num_workers表示加载数据的子进程数量

# 定义测试集的加载器
test_loader = DataLoader(dataset=test_dataSet, batch_size=128, shuffle=False,
                         num_workers=8)  # 顺序加载测试集数据,num_workers表示加载数据的子进程数量
# print("num_of_train", len(train_loader))  # 391(50000/128),相当于有391个batch,每个batch有128个样本
# print("num_of_test", len(test_loader))  # 79(10000/128),相当于有79个batch,每个batch有128个样本

猜你喜欢

转载自blog.csdn.net/xulibo5828/article/details/142987477