图像分类网络搭建的一些函数及相关配置文件(一)

1.parser = argparse.ArgumentParser()

argparse是python用于解析命令行参数和选项的标准模块,用于代替已经过时的optparse模块。argparse模块的作用是用于解析命令行参数。

我们很多时候,需要用到解析命令行参数的程序,目的是在终端窗口(ubuntu是终端窗口,windows是命令行窗口)输入训练的参数和选项。

使用步骤

我们常常可以把argparse的使用简化成下面四个步骤

1:import argparse

2:parser = argparse.ArgumentParser()

3:parser.add_argument()

4:parser.parse_args()  

上面四个步骤解释如下:首先导入该模块;然后创建一个解析对象;然后向该对象中添加你要关注的命令行参数和选项,每一个add_argument方法对应一个你要关注的参数或选项;最后调用parse_args()方法进行解析;解析成功之后即可使用。

2.定义数据集模型

import torch  
from torch.utils.data import Dataset  
from PIL import Image  
from torchvision import transforms  

class Mydataset(Dataset):  
    """自定义数据集"""  
    def __init__(self,images_path,images_class,transform=None):  
        self.images_path = images_path                   #图像路径  
        self.images_class = images_class                 #图像种类  
        self.transform = transform                       #数据预处理  
  
  
    def __getitem__(self, index):  
        img = Image.open(self.images_path[index])  
        if img.mode != 'RGB' :  
            raise ValueError("image:{} isn't RGB mode.".format(self.images_path[index])) #若不是RGB图像抛出异常  
        label = self.images_class[index]  
        if self.transform is not None:  
            img = self.transform(img)  
  
        return img,label  
  
    def __len__(self):  
        return len(self.images_path)  
  
    def collatr_fn(batch):  
        images,labels = tuple(zip(*batch))  
        images = torch.stack(images,dim=0)  
        labels = torch.as_tensor(labels)  
        return images,labels

将上述代码可以创立成名为my_dataset的.py文件,方便调用。

上述my_datase.py文件中代码大致意思为:定义一个自定义数据集类,获取自定义数据集中的图像路径、图像种类、及图像预处理方式,通过一系列操作,最后返回图像的路径及对应的标签。

3.定义数据集中图像预处理方式及实例化

images_size = 224  
data_transform = {  
    "train":transforms.Compose([transforms.RandomResizedCrop(images_size),                          #先随机采集,然后对裁剪得到的图像缩放为同一大小  
                                transforms.RandomHorizontalFlip(),                                  #以给定的概率随机水平旋转给定的PIL的图像,默认为0.5  
                                transforms.ToTensor(),                                              #将给定图像转为Tensor  
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#标准化,均值为0,标准差为1  
    "val":transforms.Compose([transforms.Resize(int(images_size * 1.143)),                          #将图片短边缩放至images_size*1.143,长宽比保持不变  
                              transforms.CenterCrop(images_size),                                   #将图片从中心裁剪成images_size大小  
                              transforms.ToTensor(),                                                #将给定图像转为Tensor  
                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} #标准化,均值为0,标准差为1  
# 实例化训练数据集  
train_dataset = Mydataset(images_path='',  
                          images_class=1,  
                          transform = data_transform["train"])  
# 实例化验证数据集  
val_dataset = Mydataset(images_path='',  
                        images_class=1,  
                        transform = data_transform["val"])  
batch_size = args.batch.size  
nw = min(os.cpu_count(),batch_size if batch_size > 1 else 0,8)  
print("Using {} dataloader workers every process".format(nw))  
  
train_loader = DataLoader(train_dataset,                         #处理好的所有数据  
                          batch_size = batch_size,               #批次数量  
                          shuffle = True,                        #打乱数据  
                          num_workers = nw,                      #加载数据的线程数  
                          collate_fn = train_dataset.collatr_fn, #batch的样本打包成一个tensor的结构  
                          pin_memory = True)                     #将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu  
val_loader = DataLoader(val_dataset,                         #处理好的所有数据  
                          batch_size = batch_size,               #批次数量  
                          shuffle = False,                        #打乱数据  
                          num_workers = nw,                      #加载数据的线程数  
                          collate_fn = val_dataset.collatr_fn, #batch的样本打包成一个tensor的结构  
                          pin_memory = True)                     #将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu

未完待续!!!

猜你喜欢

转载自blog.csdn.net/weixin_42715977/article/details/129924735