深度之眼Pytorch打卡(十):Pytorch数据预处理——数据统一与数据增强(上)

前言


  本笔记续上一篇笔记,更加深入的学习pytorch的各种数据预处理方法,包括数据标准化、尺寸调整、各种裁剪方法以及结果的可视化。本笔记的知识框架主要来源于深度之眼,并作了一些相关的拓展,拓展内容主要源自对torch文档的翻译理解,所用数据来源于网络。

  数据读取参考:深度之眼Pytorch打卡(七):Pytorch数据读取机制,DataLoader()和Dataset
  数据标准化参考:深度之眼Pytorch打卡(九):数据预处理与数据标准化(Normalize原理、常用数据集均值标准差与数据集均值标准差计算)


迭代器


  迭代器是访问集合元素的一种方式。迭代器对象从集合的第一个元素开始访问,直到所有的元素均被访问完结束,只能往前不能后退。此处要使用迭代的方式,通过Dataloader来读取数据。

  iter(): 将可迭代的对象转换成迭代器对象,此处即将Dataloader转换成迭代器对象,便于手动迭代来取一个batch的数据。

data_iter = iter(Dataloader)

  next(): 返回迭代器的下一个元素。此处即一个batch的数据和标签,0为数据,1为标签。

data = data_iter.next()[0]

数据统一


  • transform.Resize()
torchvision.transforms.Resize(size, interpolation=2)

  size: 缩放到的尺寸,整数或者序列(h,w)
  interpolation: 插值方法,缩放图片要进行插值,默认双线性插值。考虑到后面要用AlexNet来做口罩识别,其输入为224×224,所以这里将尺寸调为此值。统一样本尺寸后,才能让batch_size大于1,否则会因为维度不一致而使得数据打包失败,所以这是必做操作。interpolation=1:PIL.Image.NEAREST,interpolation=2:PIL.Image.BILINEAR,interpolation=3:PIL.Image.BICUBIC

re_img = transforms.Resize((224, 224))(pil_img)
  • transform.ToTensor()

  将PIL图像或者numpy数组转换成tensor,并且将值归一化到[0.0,1.0]之间。必做操作。

img = transforms.ToTensor()(re_img)
  • transform.Normalize()

  直接使用上一篇笔记中计算出来的均值与标准差,mean = [0.581, 0.535, 0.514],std = [0.299, 0.299, 0.304]带入进行数据标准化。

import os
from tools.dataload import DataSet
from tools.get_mean_std import get_mean_std
from torchvision import transforms
from torch.utils.data import DataLoader

label_name = {
    
    'masking': 1, 'unmasking': 0}
BATCH_SIZE = 1
mean = [0.581, 0.535, 0.514]
std = [0.299, 0.299, 0.304]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

if __name__ == '__main__':

    train_set_path = os.path.join('data', 'train_set')

    train_set = DataSet(data_path=train_set_path, label_name=label_name, transform=transform)
    train_mean, train_std = get_mean_std(train_set)
    print(train_mean, train_std)
[ 4.0162424e-05  6.9549726e-04 -5.4381520e-04] [1.0480237 1.0458964 1.0433147]

数据增强


  数据增强,又称为数据增广或者数据扩增,用于对训练集进行变换,使得训练集数据更加丰富,进而增加模型的泛化能力。在不改变数据主要特征的前提下,改变数据的次要特征,变着法的再让模型认,如果在这种情况下模型都可以表现很好,说明模型真正学到了决定数据本质的主要特征,而受次要特征的影响较小。那样,在实际应用中,哪怕次要特征变出花来,模型都有很大可能认对,即有较强的泛化能力。

  • 数据逆变换

  为了看到数据增强后的效果,需要显示图片。但是此时的数据是原始的pil图片经过transform后得到的张量,需要进行transform的逆操作恢复到pil图片,其主要是标准化逆过程,和张量转图片。想要进行标准化逆过程,就需要先从transform中找到Normalize,然后再从Normalize中取出出正过程的meanstd

# transform
Compose(
    Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
    ToTensor()
    Normalize(mean=[0.581, 0.535, 0.514], std=[0.299, 0.299, 0.304])
)

# transform.transforms
[Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR), ToTensor(), Normalize(mean=[0.581, 0.535, 0.514], std=[0.299, 0.299, 0.304])]

  filter(): 用于过滤掉列表中不符合条件的元素,返回由符合条件元素组成的新列表,其有两个参数,第一个为用于判断的函数,第二个为列表(可迭代对象)。将输入列表中的元素一个一个(迭代)带入判断函数,满足条件则放入新列表。由上可知transform.transforms是一个包含若干预处理方法组成的列表,作为第二个参数。
  lambda 表达式: 对出传入的参数的操作或判断,可定义简单函数。格式为,lambda 输入参数:对输入的判断或者操作函数。
  isinstance(): 判断一个对象是否是一个已知的类型,接收两个参数,参数一为待判断的对象,参数二为类名,基本类型,或者是它们组成的元组。两者相同则返回 True,否则返回 False。此处参数一是transform.transforms中的一个方法,参数二是transforms.Normalize

Normalize_trans = list(filter(lambda x: isinstance(x, transforms.Normalize), transform.transforms))
print(Normalize_trans)
[Normalize(mean=[0.581, 0.535, 0.514], std=[0.299, 0.299, 0.304])]

  transform_inverse.py

import torch
import numpy as np
from PIL import Image
from torchvision import transforms


def transform_inverse(img, transform):
    # 将tensor转换成pil数据
    if 'Normalize' in str(transform):
        Normalize_trans = list(filter(lambda x: isinstance(x, transforms.Normalize), transform.transforms))
        m = torch.tensor(Normalize_trans[0].mean, dtype=img.dtype, device=img.device)
        s = torch.tensor(Normalize_trans[0].std, dtype=img.dtype, device=img.device)
        img.mul_(torch.reshape(s, [-1, 1, 1])).add_(torch.reshape(m, [-1, 1, 1]))       # 需要调整形状才能通道对应
    img = torch.transpose(img, dim0=0, dim1=1)                                          # C H W ->H C W
    img = torch.transpose(img, dim0=1, dim1=2)                                          # H C W ->H W C

    img = np.array(img)*255                                                             # 去归一化
    # print(img)
    if img.shape[2] == 3:
        img = Image.fromarray(img.astype('uint8')).convert('RGB')                       # 转换成PIL RGB图像
    elif img.shape[2] == 1:
        img = Image.fromarray(img.astype('uint8').squeeze())                           # (1, H, W)->(H, W)
    else:
        print('Invalid img format')
    return img

  主函数中添加如下代码

    img = iter(train_loader).next()[0]
    img = torch.squeeze(img, dim=0)                  # (1,3,224,224)->(3,224,224)
    print(np.array(img).shape)
    pil_img = transform_inverse(img, transform)
    ax1 = plt.subplot(111)
    ax1.imshow(pil_img)
    plt.show()

  结果:
在这里插入图片描述

图1.transform逆变换结果
  • 数据增强——裁剪

  transforms.CenterCrop()

  中心裁剪,裁剪尺寸小于图像尺寸往内缩,大于往外补黑边。如图2所示。如果每次迭代取数据时裁剪位置相同,则不具有增强特点。

torchvision.transforms.CenterCrop(size)
transforms.CenterCrop(224),
transforms.CenterCrop(447),
在这里插入图片描述 在这里插入图片描述
在这里插入图片描述 在这里插入图片描述
图2.CenterCrop

  transforms.RandomCrop()

  随机裁剪,padding:裁剪前图像需要拓展边缘的宽度,padding_mode:拓展的方式,包括constant,拓展边缘填充常数值fill=0或value或(R,G,B),镜像symmetric,reflect,以边缘为轴,用镜像像素的值对拓展边缘进行填充,edge:用图像边缘像素值进行填充,不选方式不设fill默认填充常数0,即黑色。pad_if_needed:当图像小于裁剪尺寸时,为了避免出错,需要对图像进行拓展,设为True。由于裁剪具有随机性,可以增强数据。

torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
transforms.RandomCrop(447, pad_if_needed=True),
transforms.RandomCrop(224, padding=100, pad_if_needed=True, padding_mode='reflect'),
224x224 447x447(pad_if_needed=True)
在这里插入图片描述 在这里插入图片描述

在这里插入图片描述

图3.RandomCrop

  transforms.RandomResizedCrop()

  随机大小,随机长宽比裁剪图像。scale,随机的裁剪面积与原面积之比。ratio,随机的长宽比。在图像中随机长宽比,随机面积选取区域,然后resize到设定size的图片。

torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
transforms.RandomResizedCrop(224),

在这里插入图片描述

图4.RandomResizedCrop

  transforms.FiveCrop()

  在图像四个角与中心各裁剪一张size大小的图片,共形成五张。如下代码所示,FiveCrop返回的值是由5张图片构成的元组。转换成张量时ToTensor()的输入只能是PIL image或者ndarray。如果直接元组作为输入将会报错:pic should be PIL Image or ndarray. Got <class 'tuple'>。故需要在进行ToTensor()操作前,对FiveCrop输出做处理。

torchvision.transforms.FiveCrop(size)
def five_crop(img, size):
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    image_width, image_height = img.size
    crop_height, crop_width = size
    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = img.crop((0, 0, crop_width, crop_height))
    tr = img.crop((image_width - crop_width, 0, image_width, crop_height))
    bl = img.crop((0, image_height - crop_height, crop_width, image_height))
    br = img.crop((image_width - crop_width, image_height - crop_height,
                   image_width, image_height))
    center = center_crop(img, (crop_height, crop_width))
    return (tl, tr, bl, br, center)

  前面提到,匿名函数lambd可以定义简单函数,对输入进行操作。transforms中的Lambda(lambd)函数可供用户自定义的lambda来做transformlambda函数的输入就是那个图像元组pics,处理函数就是将pics中的图片一张一张的取出来进行ToTensor()操作,然后拼接在一起,形成一个(1,5,3,H,W)的五维张量。

transforms.FiveCrop(160),
transforms.Lambda(lambda pics: torch.stack([(transforms.ToTensor()(pic)) for pic in pics])),

  主函数中可视化部分做如下修改:

    for i in range(len(img)):
        pil_img = transform_inverse(img[i], transform)
        ax1 = plt.subplot(2, 3, i+2)
        ax1.set_title('processed image')
        ax1.imshow(pil_img)
    plt.show()

  结果:
在这里插入图片描述

图5.FiveCrop

  transforms.TenCrop()

  FiveCrop裁剪出的5张图片,各自进行水平或者竖直方向的翻转,故形成5X2=10张图片。默认水平翻转,即vertical_flip=False

torchvision.transforms.TenCrop(size, vertical_flip=False)
    transforms.TenCrop(160, vertical_flip=True),
    transforms.Lambda(lambda pics: torch.stack([(transforms.ToTensor()(pic)) for pic in pics])),

  主函数中可视化代码:

    for i in range(len(img)):
        pil_img = transform_inverse(img[i], transform)
        ax1 = plt.subplot(3, 4, i+2)
        ax1.set_title('processed image')
        ax1.imshow(pil_img)
    plt.show()

  结果:
在这里插入图片描述

图6.TenCrop

  注:本笔记涉及代码及数据集在下一篇笔记:深度之眼Pytorch打卡(十一):Pytorch数据预处理——数据增强(下)中完整给出。


参考


  https://blog.csdn.net/baidu_36831253/article/details/78647391
  https://www.cnblogs.com/liangxiyang/p/11288181.html
  https://www.runoob.com/python/python-func-isinstance.html

猜你喜欢

转载自blog.csdn.net/sinat_35907936/article/details/107450092