本阶段的任务是将训练数据和测试数据进行预处理和创建加载器,以供后面的网络使用。
新建load_cifar.py:
import glob
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
label_dict = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, # 类别标签对应的数字
'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
# print(label_dict)
def default_loader(path): # 定义图片加载函数
return Image.open(path).convert('RGB') # 转换成RGB模式
train_transform = transforms.Compose([ # 训练集数据预处理
# transforms.Resize((32, 32)), # 将图像大小调整为32x32
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomVerticalFlip(), # 随机垂直翻转
# transforms.RandomRotation(90), # 随机旋转90度
# transforms.RandomGrayscale(p=0.1), # 随机将图像转换为灰度图,p=0.1表示有10%的概率执行该操作
# transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5), # 调整图像的亮度、对比度、饱和度和色调
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize((0.4914, 0.4822, 0.4465), # 标准化
(0.2023, 0.1994, 0.2010))
])
# (0.4914, 0.4822, 0.4465)是均值,(0.2023, 0.1994, 0.2010)是标准差,这两组数字是根据训练集数据计算出来的
# 计算方法见:https://blog.csdn.net/xulibo5828/article/details/143143550
test_transform = transforms.Compose([ # 测试集数据预处理
# transforms.CenterCrop((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))])
class MyDataset(Dataset): # 自定义数据集
def __init__(self, img_list, # 图片的地址列表
transform=None, # 数据预处理
loader=default_loader): # 图片加载函数
super(MyDataset, self).__init__()
imgs = [] # 图片列表
for img_path in img_list:
# 图像文件的地址,典型格式为:
# E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\train\\ship\\abandoned_ship_s_000004.png
im_label_name = img_path.split('\\')[-2] # 图片所属类别的名称,这里使用的是绝对路径,文件目录分隔符为反斜杠,使用相对路径则为正斜杠
imgs.append([img_path, label_dict[im_label_name]]) # 将图片路径和对应的类别标签添加到列表中
self.imgs = imgs # 图片列表
self.transform = transform # 数据预处理
self.loader = loader # 图片加载函数
def __getitem__(self, idx): # 获取图片数据 # 请注意,这个是PyTorch的Dataset类中必须实现的方法
img_path, label = self.imgs[idx] # 获取图片路径和对应的类别标签
im_data = self.loader(img_path) # 加载图片,并得到图像的数据
if self.transform: # 如果有定义数据预处理
im_data = self.transform(im_data) # 对图像进行预处理,转换为Tensor等
return im_data, label # 返回图片数据和对应的类别标签
def __len__(self): # 返回数据集的长度 # 请注意,这个也是PyTorch的Dataset类中必须实现的方法
return len(self.imgs) # 返回图片列表的长度
# 获取训练集的文件名
train_list = glob.glob('E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\train\\*\\*.png') # 获取训练集的文件名
# 获取测试集的文件名
test_list = glob.glob('E:\\AI_tset\\cifar10_demo\\cifar-10-python\\cifar-10-batches-py\\test\\*\\*.png') # 获取测试集的文件名
# print(len(train_list)) # 50000
print(test_list[:5]) # 10000
# 定义训练数据集
trans_dataSet = MyDataset(img_list=train_list, transform=train_transform) # 自定义的数据集,地址为训练集的文件名,数据预处理为transform
# print(trans_dataSet.__len__()) # 50000
# 定义测试数据集
test_dataSet = MyDataset(img_list=test_list, transform=test_transform) # 自定义的数据集,地址为测试集的文件名,数据预处理为test_transform
# print(test_dataSet.__len__()) # 10000
# 定义训练集的加载器
train_loader = DataLoader(dataset=trans_dataSet, batch_size=128, shuffle=True,
num_workers=8) # 以随机顺序加载训练数据集,num_workers表示加载数据的子进程数量
# 定义测试集的加载器
test_loader = DataLoader(dataset=test_dataSet, batch_size=128, shuffle=False,
num_workers=8) # 顺序加载测试集数据,num_workers表示加载数据的子进程数量
# print("num_of_train", len(train_loader)) # 391(50000/128),相当于有391个batch,每个batch有128个样本
# print("num_of_test", len(test_loader)) # 79(10000/128),相当于有79个batch,每个batch有128个样本