项目介绍
此次项目为datawhale和阿里天池合作举办的cv入门赛事街道字符识别。该任务较早见于花书的作者古德费洛在12-13年在谷歌做出的研究。当时谷歌需要对极大的街景门牌号图片数据集进行数字提取以使每一个地点能在谷歌地图上有正确的门牌号信息。这样一个转码项目必然要消耗极大的人力物力,因此当时采用深度学习模型来实现自动转码,并最终取得了98%的覆盖率,大幅提高效率。
数据说明
本次项目采用谷歌公开数据集SVHM,其中测试集数据为3万张图片,验证集数据为1万张图片。
为了降低比赛的难度,对于每一张图片,都有对应的编码标签和具体的字符框的位置以用于模型的训练,具体的数据格式说明如下表所示:
Field | Description |
---|---|
top | 字符框左上角的坐标X |
height | 字符高度 |
left | 字符框左上角的坐标Y |
width | 字符框的宽度 |
label | 字符编码 |
图示如下:
因为一张图片可能包含一个或多个字符,因此在比赛数据集当中的JSON标注中,会有两个边框信息。示例图片的JSON标注如下图所示:
评价指标
项目的评测指标以字符串整体的识别率作为标准,其中任何一个字符的错误都算整体错误。score表示如下:
数据读取
此为JSON标签的读取方式:
import json
train_json = json.load(open('../input/train.json')) #改变../input路径读取文件 下同
def parse_json(d):
arr = np.array([
d['top'], d['height'], d['left'], d['width'], d['label']
])
arr = arr.astype(int)
return arr
img = cv2.imread('../input/train/000000.png') #读取图像
arr = parse_json(train_json['000000.png'])
plt.figure(figsize=(10, 10)) #以下为绘图部分
plt.subplot(1, arr.shape[1]+1, 1)
plt.imshow(img)
plt.xticks([]); plt.yticks([])
for idx in range(arr.shape[1]):#将图片当中包含字符的部分完全显现出来
plt.subplot(1, arr.shape[1]+1, idx+2)
plt.imshow(img[arr[0, idx]:arr[0, idx]+arr[1, idx],arr[2, idx]:arr[2, idx]+arr[3, idx]])
plt.title(arr[4, idx])
plt.xticks([]); plt.yticks([])
运行之后的效果如下图所示:
解题思路
本次比赛主要为提取图片当中的字符信息,其中的难点在于图片当中字符为不定长,因此可以下主要从以下三个角度对本赛题进行思考:
1:将该不定长字符识别转化为定长字符识别,也即是官方baseline给出的基本思路。在选用resnet18(官方baseline)或resnet50模型架构时,目前笔者得到的最好结果为60.36%的识别正确率以及最低2.515的val_loss(采用交叉熵损失函数)。
2:直接从不定长字符识别的角度进行出发,如CRNN字符识别模型。
3:目标检测加数字识别。本次赛题数据已经给出了每张图片当中对应数字的位置,因此可以采用先分割成每一部分,再识别对应部分的数字,最终进行组合的思路实现。此处的目标检测可用SSD,YOLO v3,同时可以考虑使用最新的YOLO v4当中的一些trick提高准确率
写在最后
本次为笔者参加的第一次项目,以这次项目作为载体,笔者将不定期得写出相关内容知识的学习总结博客。文稿不足之处望大家理解海涵,并且欢迎指正批评,希望各位看后能够有所收获,对自己在深度学习实践方面有所帮助。