在谷歌目标检测(Google object_detection) API 上训练自己的数据集

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hnsywangxin/article/details/76442228

本文未经同意禁止转载,谢谢配合!

知乎链接:https://zhuanlan.zhihu.com/p/28218410

应公司要求,利用谷歌最近开源的Google object_detection API对公司收集的数据集进行训练,并检测训练效果。通过一两天的研究以及维持四天的训练(GTX 1060  6GB),终于成功的在自己数据集上训练的任务。测试效果感觉还行,虽没有达到谷歌官方公布的数据集上跑的识别效果,但是识别率也还过得去,这主要是因为数据集没有官方做的那么规范。下图为本人挑选的一张识别率较好的图片(识别哈尔滨啤酒):


下面把本人如何一步步在自己的数据集上训练的详细步骤做个总结,一是方便自己以后操作起来更快的再次上手训练,二是方便大家能好的实现该API的一些需求。

需要说明的:

1:本教程用的模型权重参数为faster_rcnn_resnet101_coco  ,可点击进行模型的下载。

2:数据集格式需要为转换成tensorflow要求的tfrecord的形式。

3:本文在GTX 1060  6GB的显卡上训练了四天

4:如何安装tensorflow等一些依赖库,本文不再赘述,请参考:安装依赖库教程链接


过程:

1:下载Google object_detection API

下载地址

2:数据集准备:

数据集需要符合API所需的TFRecord格式,官方提供的数据集格式为PASCAL VOC格式,API已经为我们提供了将此格式转为TFRecord的代码. 但是这里我们需要注意一个细节:create_pascal_tf_record.py中的

examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                             'aeroplane_' + FLAGS.set + '.txt')
去掉'aeroplane_'。

同时,将文件中的

 
 
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
 

data/pascal_labe_map.pbtxt改为自己的数据集label

然后在“tensorflow/models/object_detection/”目录下运行以下命令

#生成训练集record
python create_pascal_tf_record.py --data_dir=`自己的训练数据集路径` \
    --year=VOC2007 --set=train --output_path=`你想保存的训练集的record路径`

#生成验证集record
 python create_pascal_tf_record.py --data_dir=`自己的验证数据集路径` \

    --year=VOC2007 --set=val --output_path=`你想保存的验证集的record路径`

注意,在data目录下选择一个.pbtxt文件,将该文件改为自己数据集的label。
执行上述两个命令后会在data文件夹下生成两个record文件。 

3:下载预训练模型

按照上文“需要说明的”第一条下载预训练模型,将下载好的模型进行解压,并将.ckpt的三个文件拷贝到models目录下。将object_detection/samples/configs/faster_rcnn_resnet101_voc07.config复制到models目录下并做如下修改:
1)num_classes:修改为之前修的的.pbtxt文件中的类别数目
2)将所有'PATH_TO_BE_CONFIGURED'修改为自己之前设置的路径

4:开始训练

执行上述三步之后我们可以开始训练了,此处需要注意两点,不然会出现模块导出错误,在tensorflow/models分别运行:


protoc object_detection/protos/*.proto --python_out=.

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

然后进入到obeject_detection目录下,运行一下命令:


python train.py --train_dir='想要保存训练模型的路径' --pipeline_config_path='你采用的.config文件路径'

5:模型可视化

运行上述四步之后您基本上只需等着模型运行完成即可,如果您想要可视化您的模型,可以运行:

tensorboard --logdir=’上面第4点提到的train_dir路径‘

然后在你的浏览器输入0.0.0.0:6006就能看到模型一些相关参数的可视化结果了。

训练完成后会生成三个.cpkt的文件,将这三个文件复制到tensorflow/models下,可利用这三个文件生成一个.pb文件,生成代码如下:

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path ’你的.config文件路径’ \
    --checkpoint_path model.ckpt-‘CHECKPOINT_NUMBER’ \
    --inference_graph_path output_inference_graph.pb
这样你就可以利用.pb文件进行目标检测了,具体步骤请参考: github.com/tensorflow/m

6:参考

https://zhuanlan.zhihu.com/p/27469690

https://github.com/tensorflow/models/blob/master/object_detection/g3doc/installation.md

https://github.com/tensorflow/models/blob/master/object_detection/g3doc/running_pets.md


如您觉得本文对你有帮助,请酌情赞赏。同时本文如有不完善的地方欢迎指正!谢谢!


猜你喜欢

转载自blog.csdn.net/hnsywangxin/article/details/76442228
今日推荐