1.parser = argparse.ArgumentParser()
argparse是python用于解析命令行参数和选项的标准模块,用于代替已经过时的optparse模块。argparse模块的作用是用于解析命令行参数。
我们很多时候,需要用到解析命令行参数的程序,目的是在终端窗口(ubuntu是终端窗口,windows是命令行窗口)输入训练的参数和选项。
使用步骤
我们常常可以把argparse的使用简化成下面四个步骤
1:import argparse
2:parser = argparse.ArgumentParser()
3:parser.add_argument()
4:parser.parse_args()
上面四个步骤解释如下:首先导入该模块;然后创建一个解析对象;然后向该对象中添加你要关注的命令行参数和选项,每一个add_argument方法对应一个你要关注的参数或选项;最后调用parse_args()方法进行解析;解析成功之后即可使用。
2.定义数据集模型
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
class Mydataset(Dataset):
"""自定义数据集"""
def __init__(self,images_path,images_class,transform=None):
self.images_path = images_path #图像路径
self.images_class = images_class #图像种类
self.transform = transform #数据预处理
def __getitem__(self, index):
img = Image.open(self.images_path[index])
if img.mode != 'RGB' :
raise ValueError("image:{} isn't RGB mode.".format(self.images_path[index])) #若不是RGB图像抛出异常
label = self.images_class[index]
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.images_path)
def collatr_fn(batch):
images,labels = tuple(zip(*batch))
images = torch.stack(images,dim=0)
labels = torch.as_tensor(labels)
return images,labels
将上述代码可以创立成名为my_dataset的.py文件,方便调用。
上述my_datase.py文件中代码大致意思为:定义一个自定义数据集类,获取自定义数据集中的图像路径、图像种类、及图像预处理方式,通过一系列操作,最后返回图像的路径及对应的标签。
3.定义数据集中图像预处理方式及实例化
images_size = 224
data_transform = {
"train":transforms.Compose([transforms.RandomResizedCrop(images_size), #先随机采集,然后对裁剪得到的图像缩放为同一大小
transforms.RandomHorizontalFlip(), #以给定的概率随机水平旋转给定的PIL的图像,默认为0.5
transforms.ToTensor(), #将给定图像转为Tensor
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#标准化,均值为0,标准差为1
"val":transforms.Compose([transforms.Resize(int(images_size * 1.143)), #将图片短边缩放至images_size*1.143,长宽比保持不变
transforms.CenterCrop(images_size), #将图片从中心裁剪成images_size大小
transforms.ToTensor(), #将给定图像转为Tensor
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} #标准化,均值为0,标准差为1
# 实例化训练数据集
train_dataset = Mydataset(images_path='',
images_class=1,
transform = data_transform["train"])
# 实例化验证数据集
val_dataset = Mydataset(images_path='',
images_class=1,
transform = data_transform["val"])
batch_size = args.batch.size
nw = min(os.cpu_count(),batch_size if batch_size > 1 else 0,8)
print("Using {} dataloader workers every process".format(nw))
train_loader = DataLoader(train_dataset, #处理好的所有数据
batch_size = batch_size, #批次数量
shuffle = True, #打乱数据
num_workers = nw, #加载数据的线程数
collate_fn = train_dataset.collatr_fn, #batch的样本打包成一个tensor的结构
pin_memory = True) #将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu
val_loader = DataLoader(val_dataset, #处理好的所有数据
batch_size = batch_size, #批次数量
shuffle = False, #打乱数据
num_workers = nw, #加载数据的线程数
collate_fn = val_dataset.collatr_fn, #batch的样本打包成一个tensor的结构
pin_memory = True) #将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu
未完待续!!!