PyTorch之基本配置&数据处理
2.1 深度学习任务的整体流程
2.1.1 机器学习任务步骤
- 首先对数据进行预处理,包括数据的统一和必要的数据变换
- 划分训练集和测试集
- 选择模型,设定损失函数和优化方法以及对应的超参数
- 可以使用sklearn这样的机器学习库中模型自带的损失函数和优化器
- 用模型去拟合训练集数据,并在验证集/测试集上计算模型表现
2.1.2 深度学习任务步骤
- 和机器学习流程类似,但在代码实现上有较大差异
- 首先载入数据,深度学习所需样本量很大,一次加载全部数据运行可能会超出内存容量而无法实现
- 为提高模型表现,利用批(batch)训练来提高模型表现,需要每次训练读取固定数量的样本送入模型中训练
- 划分训练集和测试集
- 搭建模型,需要“逐层搭建”或者预先定义好可以实现特定功能的模块,再将这些模块组装起来
- 深度学习有一些用于实现特定功能的层(如卷积层、池化层、批正则化层、LSTM层等)
- 设定损失函数和优化器,这部分和实现经典机器学习类似
- 由于模型设定的灵活性,因此损失函数和优化器要能够保证反向传播能够在用户自行定义的模型结构上实现
- 开始训练
- 涉及配置多卡GPU的内容
- 总结:
- 深度学习中训练和验证过程最大的特点在于读入数据是按批的,每次读入一个批次的数据,放入GPU中训练,然后将损失函数反向传播回网络最前面的层,同时使用优化器调整网络参数。这里会涉及到各个模块配合的问题。训练/验证后还需要根据设定好的指标计算模型表现
2.1.3 PyTorch基本配置
- 导入包后可以统一设置以下几个超参数,方便后续调试时修改
- batch size
- 初始学习率(初始)
- 训练轮次(max_epochs)
batch_size = 16 #批次的大小
lr = 1e-4 #优化器的学习率
max_epochs = 100 #训练轮次
- GPU设置
# 方案一:使用os.environ,这种情况如果使用GPU不需要设置
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
2.2 数据读取
- PyTorch的数据读入是通过Dataset+DataLoader的方式完成的
- Dataset定义数据的格式和数据变换形式
- DataLoader用iterative的方式不断读入批次数据
2.2.1 自主定义Dataset类
- 定义的类需继承PyTorch自身的Dataset类,主要包含三个函数:
__init__
: 用于向类传入外部参数,同时定义样本集__getitem__
: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
: 用于返回数据集的样本数
- 以Cifar10数据集为例给出构建Dataset类的方式
import torch
from torchvision import datasets
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)
- 使用自带的ImageFolder类:用于读取按一定结构存储的图片数据(path对应图片存放的目录,目录下包含若干子目录,每个子目录对应属于同一个类的图片)
data_transform
可以对图像进行一定的变换,如翻转、裁剪等操作,可自己定义
- 再举一个例子(图片存放在一个文件夹,另外有一个csv文件给出图片名称对应的标签,这种情况需要自己定义Dataset类)
class MyDatast(Dataset):
det __init__(self, data_dir, info_csv, image_list, transform=None):
"""
Args:
data_dir: path to image directory.
info_csv: path to the csv file containing image indexes with corresponding labels
image_list: path to the txt file contains image names to training/validation set
transform: optional transform to be applied on a sample.
"""
label_info = pd.read_csv(info_csv)
image_file = open(image_list).readlines()
self.data_dir = data_dir
self.image_file = image_file
self.label_info = label_info
self.transform = transform
def __ggetitem__(self, index):
"""
Args:
index:the index of item
Returns:
image and its labels
"""
image_name = self.image_file[index].strip('\n')
raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
label = raw_label.iloc[:,0]
image_name = os.path.join(self.data_dir, image_name)
image = Image.open(image_name).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.image_file)
2.2.2 使用DataLoader按批次读入数据
from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
- 解释:
- batch_size:样本是按“批”读入的,batch_size是每次读入的样本数
- num_workers:有多少个进程用于读取数据
- shuffle:是否将读入的数据打乱
- drop_last:对于样本最后一部分没有达到批次数的样本,使其不再参与训练
- 查看加载的数据(PyTorch中DataLoader读取可以使用next和iter来完成)
import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
plt.imshow(images[0].transpose(1,2,0))
plt.show()
2.3 数据增强
-
处理的目的:
-
增强模型鲁棒性
-
扩充数据容量
-
2.3.1 反转&旋转&缩放&裁剪
- 反转
new_im = transforms.RandomHorizontalFlip(p=1)(im) #p表示概率
new_im.save(os.path.join(outfile, '1_1.jpg'))
new_im = transforms.RandomVerticalFlip(p=1)(im)
new_im.save(os.path.join(outfile, '1_2.jpg'))
-
旋转
new_im = transforms.RandomRotation(45)(im) #随即旋转45度
-
缩放
new_im = transforms.Resize((100, 200))(im)
-
裁剪
new_im = transforms.RandomCrop(100)(im) #裁剪出100×100的区域
new_im.save(os.path.join(outfile, '4_1.jpg'))
new_im = transforms.CencerCrop(100)(im) #中心裁剪
new_im.save(os.path.join(outfile, '4_2.jpg'))
2.3.2 亮度&对比度&饱和度
- 亮度
new_im = transforms.ColorJitter(brightness=1)(im)
- 对比度
new_im = transforms.ColorJitter(contrast=1)(im)
- 饱和度
new_im = transforms.ColorJitter(saturation=0.5)(im)
资料参考来源:1. Datawhale社区《深入浅出PyTorch教程》 2. 有三AI《PyTorch入门及实战》 3. 其他零散网络资源