详细的Object_detection api调用步骤,扎实踩过tensorflow应用时的那些坑

       Tensorflow为目标检测提供了很好的api,但其调用过程却涉及到环境设置、数据准备等各方面,比较繁杂。本文在ubuntu系统,以tensorflow 1.7为环境,逐步介绍api调用过程的每一个细节,如果小白刚接触这一块,可以按本文逐步实施,训练自己的数据,并得到网络和结果,开始object detection的第一步。(感谢魅哥的友情指导。文章中也引用了其他大侠提供的程序,在此一并表示感谢。)

1.从GitHub上下载Tensorflow的Object Detection API,地址是(http://github.com/tensorflow/model),可以通过git下载(命令 git clone http://github.com/tensorflow/model.git),也可以通过Tortoise来下载。需要注意的是,Tensorflow目前在GitHub上就一个链接,需要到官网去下载。最好翻墙出去,不然速度非常慢。(也可以到收集的程序目录中去直接拷贝)

2.下载完毕后,会有一个models文件夹,下面有一个research文件夹。之后的操作以research为根目录,设置的都是相对路径。

3.下载protoc的2.6以上版本,下载地址 http://github.com/google/protobuf/release,注意下载版本要和系统信息匹配。

4.使用protoc对proto文件进行编译,目的是把research/object_detection/protos目录下的.proto文件编译成.py文件。在系统CMD下执行命令 protoc object_detection/protos/*.proto --python_out=. 

运行完成后,在research/object_detection/protos目录下,每一个 .proto文件都会生成对应的.py文件。

5.将Slim文件夹添加到系统的系统的PYTHONPATH环境变量当中。输入命令vim ~/.bashrc 进入文件进行编辑(按i键进入修改模式),输入命令 export PYTHONPYTH='/home/pc/Deep-Learning-21-Examples-master/chapter_5/research/slim',当然,绝对路径自行修改(Esc退出文件, :wq保存)。之后再 source ~/.bashrc 保存设置。可以通过 echo $PYTHONPATH 查看环境变量。

添加环境变量之后,可以通过  python object_detection/builders/model_builder_test.py 进行分测试。

6. 开始训练的模型,首先是数据准备

(1)将新的数据集按照 Pascal VOC的格式,拷到object_detection目录下,以ROS1文件夹为例。

数据集文件夹的内部结构为

数据集放在VOCdevkit文件夹下,RSDS2016文件夹的名称可以改,但必须要和.xml文件中的描述一致,不然会报错。

    

另外在/research/object_detection的create_pascal_tf_record.py文件中,也需要对years参数进行调整

 

RSDS2016/ 目录下包含三个文件夹,其中JPEGImages/ 是数据集的图片,Annotations/ 是数据集的标签(xml文件),ImageSets/ 目录下有一个Main文件夹,再往下是四个文件夹,包含了训练数据和测试数据的分类信息,也是程序运行时首先读取的数据。

(2)这四个文件在数据集中通常是没有的,需要自己来生成。有file_text.py文件,需要设置其中的绝对路径和比例。该文件放在object_detection目录下。

# -*- coding:utf-8 -*-

'''

    该代码是将数据转为VOC2007,ImageSets里所有文件

'''



import os

__author__ ='chendingxin'

_IMAGE_SETS_PATH= '/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/ImageSets'

_MAin_PATH ='/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/ImageSets/Main'

_XML_FILE_PATH= '/home/pc/abc/research/object_detection/RED/VOCdevkit/RSDS2016/Annotations'



if __name__ == '__main__':

    if os.path.exists(_IMAGE_SETS_PATH):

        print('ImageSets dir is already exists')

    if os.path.exists(_MAin_PATH):

        print('Main dir is already in ImageSets')

    else:

        os.mkdir(_IMAGE_SETS_PATH)

        os.mkdir(_MAin_PATH)

    print(_MAin_PATH)

    

    # 测试集

    f_test =open(os.path.join(_MAin_PATH,'test.txt'),'w')



    # 训练和验证集

    f_trainval =open(os.path.join(_MAin_PATH,'trainval.txt'),'w')



    # trainval中训练部分

    f_train =open(os.path.join(_MAin_PATH,'train.txt'),'w')



    # trainval中验证集

    f_val =open(os.path.join(_MAin_PATH,'val.txt'),'w')



    # 遍历XML文件夹

    for root, dirs, files in os.walk(_XML_FILE_PATH):

        i =1

        j =1

        for file in files:

            if not(i % 5):  # 作为测试集,设置比例

                f_test.writelines(str(file).split('.')[0] + '\n')

            else:   # 训练和验证集

                f_trainval.writelines(str(file).split('.')[0]+'\n')

                if j % 2:  # 训练集,设置比例

                    f_train.writelines(str(file).split('.')[0]+'\n')

                else:

                    # 验证集,设置比例

                    f_val.writelines(str(file).split('.')[0]+'\n')

                j +=1

            i +=1

        f_test.close()

        f_train.close()

        f_trainval.close()

        f_val.close()

(3)部分.xml文件含有描述,如下

需要将对.xml进行修改,去掉这一行,程序放在/home/pc/abc/research/object_detection/ROS1/1/ 目录下

#coding=utf-8
import os
import os.path
import xml.dom.minidom
 
path="/home/pc/abc/ROS/Annotation"
files=os.listdir(path)
s=[]

def file_extension(path): 
    return os.path.splitext(path)[1] 

for xmlFile in files: 
    if not os.path.isdir(xmlFile): 
        if file_extension(xmlFile) == '.xml':
            print(xmlFile)
            with open(xmlFile,"r") as f:
                lines = f.readlines()
            with open(xmlFile,"w") as f_w:
                for line in lines:
                    if "<?xml" in line:
                        continue
                    f_w.write(line)

(4)在 /object_detection/ROS1目录下的pascal_label_map.pbtxt文件,需要根据分类数目和名称进行调整。

(5)Tensorflow只能读取.tfrecord格式的文档,因此,需要将以上pascal VOC格式的各类文件打包成相应格式,再输入网络。用到的是~/research/object_detection/目录下的 create_pascal_tf_record.py,注意之前提到的,对其中的years一项(2处)进行修改。另外,为了防止部分数据集打标签时不规范,出现了边界溢出的现象而报错,需要对目标标签进行容错。整体程序调整如下

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Convert raw PASCAL dataset to TFRecord for object_detection.

Example usage:
    ./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \
        --year=VOC2012 \
        --output_path=/home/user/pascal.record
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util


flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'ROS1/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged', 'RSDS2016']


def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset  (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
  img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
  full_path = os.path.join(dataset_directory, img_path)
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  for obj in data['object']:
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
      continue

    difficult_obj.append(int(difficult))

    xmin.append((float(obj['bndbox']['xmin']) / width>0) * (float(obj['bndbox']['xmin']) / width))
    ymin.append((float(obj['bndbox']['ymin']) / height>0) * (float(obj['bndbox']['ymin']) / height))
    xmax.append((float(obj['bndbox']['xmax']) / width<1) * (float(obj['bndbox']['xmax']) / width) + (float(obj['bndbox']['xmax']) / width>=1) )
    ymax.append((float(obj['bndbox']['ymax']) / height<1) * (float(obj['bndbox']['ymax']) / height) + (float(obj['bndbox']['ymax']) / height>=1))
    classes_text.append(obj['name'].encode('utf8'))
    classes.append(label_map_dict[obj['name']])
    truncated.append(int(obj['truncated']))
    poses.append(obj['pose'].encode('utf8'))

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example


def main(_):
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if FLAGS.year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))

  data_dir = FLAGS.data_dir
  years = ['VOC2007', 'VOC2012', 'RSDS2016']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]

  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  for year in years:
    logging.info('Reading from PASCAL %s dataset.', year)
    examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                 FLAGS.set + '.txt')
    annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
    examples_list = dataset_util.read_examples_list(examples_path)
    for idx, example in enumerate(examples_list):
      if idx % 100 == 0:
        logging.info('On image %d of %d', idx, len(examples_list))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()


if __name__ == '__main__':
  tf.app.run()

(6)在Terminal中运行2段程序

python3 create_pascal_tf_record1.py --data_dir ROS1/VOCdevkit/ --year=RSDS2016 --set=train --output_path=ROS1/pascal_train.record

python3 create_pascal_tf_record1.py --data_dir ROS1/VOCdevkit/ --year=RSDS2016 --set=val --output_path=ROS1/pascal_val.record

在~/research/object_detection/ROS1/目录下生成 pascal_train.record 和 pascal_val.record 两个文件。

(7)至此,数据准备完毕。

6. 下载现有的模型。此处采用的COCO数据集上训练的 Faster R-CNN + Inception_Resnet_V2. 下载地址为 http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017.tar.gz

下好,解压之后产生5个文件,放在 ~/research/object_detection/ROS1/pretrained/ 目录下

7. 在 ~/research/object_detection/samples/configs/ 目录下,找到模型对应的 .config 文件,将其拷到~/research/object_detection/ROS1/ 目录下,命名为 ROS1.config, 并进行修改

(1)num_classes 分类数由需要训练的数据集决定;

(2)eval_config 中的 num_examples,为验证数据集的大小,与val.txt中的图象个数(行数)相同;

(3)对5处路径进行修改,将“ROS”直接替换;

(4)dropout_keep_probability,防止过拟合;

(5)num_steps总训练步数;

(6)learning_rate设置学习率;

(7)min_dimension、max_dimension 预处理,图像缩放后的最小值和最大值(内存受限时,注意设置)。

8.将 ~/research/object_detection/ROS1/train_dir/ 目录清空,用于存放训练数据。

9.开始训练。在Terminal窗口输入命令

python3 train.py --train_dir ROS1/train_dir/ --pipeline_config_path ROS1/ROS1.config

即开始训练过程

10. 导出模型。训练模型保存在 ~/research/object_detection/ROS1/train_dir/ 中,

在Terminal中输入

python3 export_inference_graph.py --input_type image_tensor --pipeline_config_path ROS1/ROS1.config --trained_checkpoint_prefix ROS1/train_dir/model.ckpt-327 --output_directory ROS1/export/

通过运行export_inference_graph.py文件,将相应的数据导出为对应的模型,保存在 ~/research/object_detection/ROS1/export/ 目录下的 frozen_inference_graph.pb。

目录下的其他文件可以去掉。

11. 准备运行训练的模型。指定路径到 ~/research/object_detection/,进入jupyter notebook, 打开retain_object_detection_tutorial .ipynb文件

修改其中的路径,

将下载模型的相关内容添加注释

在 ~/research/object_detection/test_images/ 目录下添加待检测图片,并按 "image0.jpg"的形式修改名称。

  

12. 在jupyter notebook 中运行新模型进行检测,结果实时显示并保存在 ~/research/object_detection/test_images/ 目录,格式 .png

猜你喜欢

转载自blog.csdn.net/weixin_39153202/article/details/81631417