Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
学习目标
学习Python和Pytorch中图像读取
学会扩增方法和使用Pytorch读取赛题数据
1.Python中的图像读取
在python中进行图像读取的方法有多种,这里介绍两种图像读取方法
(1)在python中利用pillow库进行图像读取操作
import numpy as np
from PIL import Image
# 打开图像
im = Image.open('D:/paper/SVHN/IN/mchar_val/mchar_val/000000.png')
im.show()
运行代码后,读取到图片
(2) 在python中利用matplotlib库进行图像读取操作
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
im = mpimg.imread('D:/paper/SVHN/IN/mchar_val/mchar_val/000000.png')
plt.imshow(im) # 显示图片
运行代码之后得到如下效果
2.数据扩增方法
在了解了图片读取的方法后,我们继续了解一下对图片进行数据扩增的方法。这一节包括对数据扩增的简单介绍,常用的数据扩增方法以及常用的数据扩增库。
在数据数量较小或数据种类比较单一时,就需要对数据进行扩增,来增加可以用来训练和学习的数据的数量和种类。
常见的数据扩增方法有裁剪、灰度变换、像素填充、随机旋转等多种方法。
transforms.CenterCrop 对图片中心进行裁剪
transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
transforms.Grayscale 对图像进行灰度变换
transforms.Pad 使用固定值进行像素填充
transforms.RandomAffine 随机仿射变换
transforms.RandomCrop 随机区域裁剪
transforms.RandomHorizontalFlip 随机水平翻转
transforms.RandomRotation 随机旋转
transforms.RandomVerticalFlip 随机垂直翻转
3.Pytorch读取数据
参考baseline中的代码
(1)首先,进行数据集的定义读取
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
(2)定义读读取数据
train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.RandomCrop((60, 120)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=40,
shuffle=True,
num_workers=10,
)
val_path = glob.glob('../input/val/*.png')
val_path.sort()
val_json = json.load(open('../input/val.json'))
val_label = [val_json[x]['label'] for x in val_json]
print(len(val_path), len(val_label))
val_loader = torch.utils.data.DataLoader(
SVHNDataset(val_path, val_label,
transforms.Compose([
transforms.Resize((60, 120)),
# transforms.ColorJitter(0.3, 0.3, 0.2),
# transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=40,
shuffle=False,
num_workers=10,
)