DataWhale 零基础入门语义分割-地表建筑物识别-Task2
文章目录
数据扩增
对语义分割任务中常见的数据扩增方法进行介绍,并使用OpenCV 和albumentations 两个库完成具体的数据扩增操作。
主要内容为数据扩增方法、OpenCV 数据扩增、albumentations 数据扩增和Pytorch 读取赛题数据四个部分组成。
一、数据扩增方法
数据扩增是一种有效的正则化方法,可以防止模型过拟合,在深度学习模型的训练过程中应用广泛。
数据扩增的目的是增加数据集中样本的数据量,同时也可以有效增加样本的语义空间。
需注意:
- 不同的数据,拥有不同的数据扩增方法;
- 数据扩增方法需要考虑合理性,不要随意使用;
- 数据扩增方法需要与具体任何相结合,同时要考虑到标签的变化;
对于图像分类,数据扩增方法可以分为两类:
- 标签不变的数据扩增方法:数据变换之后图像类别不变;
- 标签变化的数据扩增方法:数据变换之后图像类别变化;
而对于语义分割而言,常规的数据扩增方法都会改变图像的标签。如水平翻转、垂直翻转、旋转90
二、OpenCV 数据扩增
1.读取原始数据
代码如下:
img = cv2.imread(train_mask[’name’].iloc[0])
mask = rle_decode(train_mask[’mask’].iloc[0])
2.使用OpenCv进行翻转
代码如下:
'''
filename 是文件名称;filecode是进行的操作
filecode = 1 水平翻转
filecode = 0 垂直翻转
filecode = -1 水平垂直翻转
'''
cv2.flip(filename, filecode)
3.对图像进行随机裁剪
代码如下(以256*256为例):
x, y = np.random.randint(0, 256), np.random.randint(0, 256)
img = img[x:x+256, y:y+256]
三、albumentations 数据扩增
albumentations 是基于OpenCV 的快速训练数据增强库,拥有非常简单且强大的可以用于多种任务(分割、检测)的接口,易于定制且添加其他框架非常方便。
albumentations 它可以对数据集进行逐像素的转换,如模糊、下采样、高斯造点、高斯模糊、动态模糊、RGB 转换、随机雾化等;也可以进行空间转换(同时也会对目标进行转换),如裁剪、翻转、随机裁剪等。
具体介绍!!!(转载)
代码如下:
import albumentations as A
# 水平翻转:p表示的是概率
augments = A.HorizontalFlip(p=1)(image=img, mask=mask)
img_aug, mask_aug = augments[’image’], augments[’mask’]
# 随机裁剪
augments = A.RandomCrop(p=1, height=256, width=256)(image=img, mask=mask)
img_aug, mask_aug = augments[’image’], augments[’mask’]
# 旋转
augments = A.ShiftScaleRotate(p=1)(image=img, mask=mask)
img_aug, mask_aug = augments[’image’], augments[’mask’]
# 组合变换
trfm = A.Compose([
A.Resize(256, 256),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
# 随机旋转90度
A.RandomRotate90(),
])
四、pytorch数据读取
1.定义Dataset:
代码如下:
import torch.utils.data as D
class TianChiDataset(D.Dataset):
def __init__(self, paths, rles, transform):
self.paths = paths
self.rles = rles
self.transform = transform
self.len = len(paths)
def __getitem__(self, index):
img = cv2.imread(self.paths[index])
mask = rle_decode(self.rles[index])
augments = self.transform(image=img, mask=mask)
return self.as_tensor(augments['image']), augments['mask'][None]
def __len__(self):
return self.len
2.实例化Dataset对象:
trfm = A.Compose([
A.Resize(IMAGE_SIZE, IMAGE_SIZE),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(),
])
dataset = TianChiDataset(
train_mask['name'].values,
train_mask['mask'].fillna('').values,
trfm
)