比赛链接:零基础入门 CV 赛事 - 街景字符编码识别
打卡任务:数据读取与数据扩增
数据读取
调用PIL库读取数据
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 设置最长的字符长度为5个
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
数据分析
利用json文件中的标签信息,将ground truth还原,以便进一步分析数据
def plot_rectangle(root_path, img_list, type):
# 新建文件夹,将画出GT的图像另存
if not os.path.exists("data_bb"):
os.mkdir("data_bb")
if not os.path.exists("data_bb/train_bb") and type == "train":
os.mkdir("data_bb/train_bb")
save_path = "data_bb/train_bb"
if not os.path.exists("data_bb/val_bb") and type == "val":
os.mkdir("data_bb/val_bb")
save_path = "data_bb/val_bb"
for i in img_list:
img_path = join(root_path, i[0])
img = cv2.imread(img_path)
for j in range(len(i[1][0])):
left = i[1]["left"][j]
top = i[1]["top"][j]
height = i[1]["height"][j]
width = i[1]["width"][j]
label = i[1]["label"][j]
# print(left, top, left+width, top+height)
cv2.rectangle(img, (int(left), int(top)), (int(left+width), int(top+height)), (0, 0, 255), 1)
cv2.imwrite(join(save_path, i[0]), img)
实现结果:
浏览整个数据集,初步分析发现:
- 边界框均为AABB型;
- 需要识别的字符几乎都位于图像中间位置;
- 字符数量多数为2到3个(最少为1个,最多为6个);
- 存在边框漏标现象:
漏标的字符会对模型训练产生干扰; - 图像长宽比较统一,几乎均为长边矩形;
- 图像大小极不均匀
小的边才10个像素左右,大的边达到将近900个像素。但从图中也可看出,长宽比较均匀,均为长边矩形;
数据增强
transforms.CenterCrop(crop_size)
中心裁剪
transforms.RandomRotation(angle)
随机选择角度
transforms.Grayscale(num_output_channels=3)
灰度变换
transforms.RandomPerspective(distortion_scale=0.5, p=1, interpolation=2)
图像扭曲
其他可供尝试的方法:mixup
、cutmix
、blur
等等。