Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增

学习目标

学习Python和Pytorch中图像读取
学会扩增方法和使用Pytorch读取赛题数据

1.Python中的图像读取

在python中进行图像读取的方法有多种,这里介绍两种图像读取方法

(1)在python中利用pillow库进行图像读取操作

import numpy as np
from PIL import Image
# 打开图像
im = Image.open('D:/paper/SVHN/IN/mchar_val/mchar_val/000000.png')
im.show()

运行代码后,读取到图片
在这里插入图片描述
(2) 在python中利用matplotlib库进行图像读取操作

import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
im = mpimg.imread('D:/paper/SVHN/IN/mchar_val/mchar_val/000000.png')
plt.imshow(im) # 显示图片

运行代码之后得到如下效果
在这里插入图片描述

2.数据扩增方法

在了解了图片读取的方法后,我们继续了解一下对图片进行数据扩增的方法。这一节包括对数据扩增的简单介绍,常用的数据扩增方法以及常用的数据扩增库。

在数据数量较小或数据种类比较单一时,就需要对数据进行扩增,来增加可以用来训练和学习的数据的数量和种类。

常见的数据扩增方法有裁剪、灰度变换、像素填充、随机旋转等多种方法。

transforms.CenterCrop 对图片中心进行裁剪
transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
transforms.Grayscale 对图像进行灰度变换
transforms.Pad 使用固定值进行像素填充
transforms.RandomAffine 随机仿射变换
transforms.RandomCrop 随机区域裁剪
transforms.RandomHorizontalFlip 随机水平翻转
transforms.RandomRotation 随机旋转
transforms.RandomVerticalFlip 随机垂直翻转

3.Pytorch读取数据

参考baseline中的代码
(1)首先,进行数据集的定义读取

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

(2)定义读读取数据

train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.RandomCrop((60, 120)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size=40, 
    shuffle=True, 
    num_workers=10,
)

val_path = glob.glob('../input/val/*.png')
val_path.sort()
val_json = json.load(open('../input/val.json'))
val_label = [val_json[x]['label'] for x in val_json]
print(len(val_path), len(val_label))

val_loader = torch.utils.data.DataLoader(
    SVHNDataset(val_path, val_label,
                transforms.Compose([
                    transforms.Resize((60, 120)),
                    # transforms.ColorJitter(0.3, 0.3, 0.2),
                    # transforms.RandomRotation(5),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size=40, 
    shuffle=False, 
    num_workers=10,
)

猜你喜欢

转载自blog.csdn.net/m0_47024418/article/details/106321932