pytorch之数据处理

在PyTorch中,数据加载可通过自定义的数据集对象实现。数据集对象被抽取为DataSet类,实现自定义的数据集需要集成DataSet,并实现两个方法。

__getitem__ : 返回一条数据或一个样本。

__len__ : 返回样本的数量。

有时候数据是图片,图片的大小形状不一,返回的样本数值归一化至[-1,1]。torchvision提供了很多视觉图像处理的工具,其中transform模块提供了对PIL Image对象和Tensor对象的常用操作。

对 PIL Image的常见操作如下:

Resize : 调整图片尺寸

CenterCrop、RandomCrop、RandomSizedCrop:剪裁图片

pad:填充

ToTensor : 将PIL Image对象转成Tensor,会自动将[0,255]归一化至[0,1]。

对Tensor的常见操作如下:

Normalize : 标准化。即减去均值除以标准差

ToPILImage:将Tensor转为PIL Image对象。

如果要对图片进行多个操作,可通过Compose将这些操作拼接起来,类似于nn.Sequential。

transform = transforms.Compose(
                [transforms.Resize(224), # 缩放图片,保持长宽比不变
                 transforms.CenterCrop(224), # 从图片中间切出224 * 224的图片
                 transforms.Totensor(), # 将图片转换成Tensor,归一化至[0,1]
                ])
class MyData(data.DataSet):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms
    
    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = 0 if 'dog' in img_path.split('/')[-1] else 1
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

torchvision已经预先实现了常用的Dataset,包括CIFAR-10,ImageNet,COCO,MNIST,LSUN等数据集,可通过调用相应的对象来调用相关数据集。

下面介绍DataSet——ImageFolder,ImageFolder假设所有的文件按文件夹保存,每个文件下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root,transform=None,traget_transform=None,loader=default_loader)

它主要有以下四个参数:

root : 在root指定的路径下寻找图片。

transform : 对PIL Image进行转换操作,transform的输入是使用loader读取图片的返回对象。

target_transform : 对label的转换。

loader:指定加载图片的函数,默认操作时读取为PIL Image对象。

label是按照文件夹名顺序排序后存成字典的,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致。

Dataset只负责数据的抽象,一次调用__getitem__只返回一个样本。前面提到过,在训练神经网络时,时对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,Pytorch提供了DataLoader帮助我们实现这些功能。

DataLoader(dataset, batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)

datadset:加载的数据集

batch_size:批大小

shuffle:是否将数据打乱

sampler:样本抽样

num_workers:使用多进程加载的进程数,0代表不使用多进程。

collate_fn : 如何将多个人样本数据拼接成一个batch,一般使用默认的拼接方式即可。

pin_memory:是否将数据保存在pin memory区,pin memory中的数据转换到GPU会快一些。

drop_last : dataset中的数据个数可能不是batch_size的整数倍。

dataloader是一个可迭代的对象,可以像使用迭代器一样使用它。

发布了16 篇原创文章 · 获赞 3 · 访问量 719

猜你喜欢

转载自blog.csdn.net/FeNGQiHuALOVE/article/details/104505202