【datawhale】学习小组打卡博客2

比赛链接:零基础入门 CV 赛事 - 街景字符编码识别
打卡任务:数据读取与数据扩增

数据读取

调用PIL库读取数据

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 设置最长的字符长度为5个
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

数据分析

利用json文件中的标签信息,将ground truth还原,以便进一步分析数据

def plot_rectangle(root_path, img_list, type):
	# 新建文件夹,将画出GT的图像另存
    if not os.path.exists("data_bb"):
        os.mkdir("data_bb")
    if not os.path.exists("data_bb/train_bb") and type == "train":
        os.mkdir("data_bb/train_bb")
        save_path = "data_bb/train_bb"
    if not os.path.exists("data_bb/val_bb") and type == "val":
        os.mkdir("data_bb/val_bb")
        save_path = "data_bb/val_bb"

    for i in img_list:
        img_path = join(root_path, i[0])
        img = cv2.imread(img_path)
        for j in range(len(i[1][0])):
            left = i[1]["left"][j]
            top = i[1]["top"][j]
            height = i[1]["height"][j]
            width = i[1]["width"][j]
            label = i[1]["label"][j]
            # print(left, top, left+width, top+height)
            cv2.rectangle(img, (int(left), int(top)), (int(left+width), int(top+height)), (0, 0, 255), 1)
        cv2.imwrite(join(save_path, i[0]), img)

实现结果:
在这里插入图片描述
浏览整个数据集,初步分析发现:

  1. 边界框均为AABB型;
  2. 需要识别的字符几乎都位于图像中间位置;
  3. 字符数量多数为2到3个(最少为1个,最多为6个);
  4. 存在边框漏标现象:
    在这里插入图片描述
    漏标的字符会对模型训练产生干扰;
  5. 图像长宽比较统一,几乎均为长边矩形;
  6. 图像大小极不均匀
    train_img_size
    小的边才10个像素左右,大的边达到将近900个像素。但从图中也可看出,长宽比较均匀,均为长边矩形;

数据增强

  1. transforms.CenterCrop(crop_size)中心裁剪
    在这里插入图片描述
  2. transforms.RandomRotation(angle)随机选择角度
    在这里插入图片描述
  3. transforms.Grayscale(num_output_channels=3)灰度变换
    在这里插入图片描述
  4. transforms.RandomPerspective(distortion_scale=0.5, p=1, interpolation=2)图像扭曲
    在这里插入图片描述
    其他可供尝试的方法:mixupcutmixblur等等。

猜你喜欢

转载自blog.csdn.net/weixin_45612763/article/details/106238725