PyTorch学习笔记(8)transforms(2)

数据增强

数据增强又称为数据增光,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力

transforms Crop

transforms.CenterCrop

功能:从图像中心裁剪图片
size 所需裁剪图片尺寸

transforms.RandomCrop

功能 从图片中随机裁剪出尺寸为size的图片 随机 是指位置上的随机
size 所需裁剪图片尺寸
padding 设置填充大小
(1)当为a时 上下左右填充a个像素
(2)当为(a,b)时,上下填充b个像素,左右填充a个像素
(3) 当为(a,b,c,d)时,左,上,右,下,分别填充a,b,c,d个像素
pad_if_need 若图像小于设定size 则填充

padding_mode 填充模式 有四种模式

  1. constant 像素值由fill设定
  2. edge 像素值由图像边缘像素决定
  3. reflect 镜像填充,最后一个像素不镜像 对边缘进行2个长度的填充
    eg.[1,2,3,4] --> [3,2,1,2,3,4,3,2] [3,2|,1,2,3,4|,3,2] 1是边缘像素 不镜像 4是边缘像素 不镜像
  4. symmetric 镜像填充,最后一个像素镜像 eg.[1,2,3,4] --> [2,1,1,2,3,4,4,3] [2,1,|1,2,3,4,|4,3]

RandomResizedCrop

功能:随机大小、长宽比裁剪图片
size 所需裁剪图片尺寸
scale 随机裁剪面积比例 默认(0.08,1)
ratio 随机长宽比,默认(3/4,4/3)
interpolation 插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

FiveCrop

TenCrop

功能 在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对着5张图片进行水平或者垂直镜像获得10张图片
size 所需裁剪图片尺寸
vertical_flip 是否垂直翻转

transforms – Flip

1.RandomHorizontalFlip
2.RandomVerticalFlip
功能 依概率水平(左右)或 垂直(上下)翻转图片
p 翻转概率

RandomRotation

功能 随机旋转图片
degrees 旋转角度
(1)当为a时,在(-a,a)之间选择旋转角度
(2)当为(a,b)时,在(a,b)之间选择旋转角度
resample 重采样方法
expand 是否扩大图片,以保存原图信息
center 旋转点设置,默认中心旋转

# -*- coding: utf-8 -*-

import os
import numpy as np
import torch
import random
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tools.my_dataset import RMBDataset
from PIL import Image
from matplotlib import pyplot as plt


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed(1)  # 设置随机种子

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
rmb_label = {"1": 0, "100": 1}


def transform_invert(img_, transform_train):
    """
    将data 进行反transfrom操作
    :param img_: tensor
    :param transform_train: torchvision.transforms
    :return: PIL image
    """
    # 对 Normalize 进行反操作
    if 'Normalize' in str(transform_train):
        norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
        mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
        std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
        img_.mul_(std[:, None, None]).add_(mean[:, None, None])
    # 对通道进行变换
    img_ = img_.transpose(0, 2).transpose(0, 1)  # C*H*W --> H*W*C
    # 将 0 - 1 上的数据*255
    img_ = np.array(img_) * 255

    # 根据 C 将 img 转换成 RGB 图像 或 灰度图像
    if img_.shape[2] == 3:
        img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
    elif img_.shape[2] == 1:
        img_ = Image.fromarray(img_.astype('uint8').squeeze())
    else:
        raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]) )

    return img_

# ============================ step 1/5 数据 ============================
split_dir = os.path.join("data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


train_transform = transforms.Compose([
    # 先将图片统一到224*224
    transforms.Resize((224, 224)),

    # 1 CenterCrop
    # 使用 CenterCrop 对图像进行裁剪 裁剪196的图片
    # transforms.CenterCrop(196),     # 512

    # 2 RandomCrop
    transforms.RandomCrop(224, padding=16),
    # transforms.RandomCrop(224, padding=(16, 64)),
    # fill 中三个元素 分别对应 RGB 三个通道
    # transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),  #将红色通道设置成255  其他两个通道设置成0
    #当RandomCrop中的参数 大于图像的参数时 pad_if_need 要设置成True
    # transforms.RandomCrop(512, pad_if_needed=True),   # pad_if_needed=True
    # transforms.RandomCrop(224, padding=64, padding_mode='edge'),
    # transforms.RandomCrop(224, padding=64, padding_mode='reflect'),
    # transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),

    # 3 RandomResizedCrop
    # transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),

    # 4 FiveCrop
    # transforms.FiveCrop(112),
    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 5 TenCrop
    # transforms.TenCrop(112, vertical_flip=False),
    # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

    # 1 Horizontal Flip
    # transforms.RandomHorizontalFlip(p=1),

    # 2 Vertical Flip
    # transforms.RandomVerticalFlip(p=0.5),

    # 3 RandomRotation
    # transforms.RandomRotation(90),
    # transforms.RandomRotation((90), expand=True),
    # transforms.RandomRotation(30, center=(0, 0)),
    # transforms.RandomRotation(30, center=(0, 0), expand=True),   # expand only for center rotation

    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 训练 ============================
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        # inputs.shape
        # Out[2]: torch.Size([1, 3, 196, 196])
        # 1 是 batchsize 的大小
        # 3 是 通道 由于是RGB图像 所以是3通道
        # 196 196  图像的高和宽

        # transform_invert 对transforms 进行逆操作 可以观察到模型输入的数据是什么样的
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

        # bs, ncrops, c, h, w = inputs.shape
        # for n in range(ncrops):
        #     img_tensor = inputs[0, n, ...]  # C H W
        #     img = transform_invert(img_tensor, train_transform)
        #     plt.imshow(img)
        #     plt.show()
        #     plt.pause(1)

发布了21 篇原创文章 · 获赞 0 · 访问量 230

猜你喜欢

转载自blog.csdn.net/qq_33357094/article/details/104444885