create_pascal_tf_record.py 代码解析

在用 TensorFlow 做目标检测(object detection)时,可以用 create_pascal_tf_record.py 将 Pascal VOC 数据集转换成 TFRecord 格式。
代码源地址:https://github.com/tensorflow/models/blob/master/object_detection/create_pascal_tf_record.py(该文件在仓库中现已迁移至 research/object_detection/dataset_tools/create_pascal_tf_record.py)

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function



import hashlib

import io

import logging

import os



from lxml import etree

import PIL.Image

import tensorflow as tf



from object_detection.utils import dataset_util

from object_detection.utils import label_map_util





flags = tf.app.flags
# Command-line flag definitions. Each DEFINE_* call takes
# (flag name, default value, help description); main() reads the values via FLAGS.
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')

# NOTE(review): this help text mentions a "merged set", but 'merged' is a value
# accepted by --year (see YEARS below), not by --set (see SETS below).
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '

                    'merged set.')

# Joined under data_dir/<year>/ in main().
flags.DEFINE_string('annotations_dir', 'Annotations',

                    '(Relative) path to annotations directory.')

flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')

flags.DEFINE_string('output_path', '', 'Path to output TFRecord')

flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',

                    'Path to label map proto')

# When True, objects marked <difficult>1</difficult> are left out of the record.
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '

                     'difficult instances')

FLAGS = flags.FLAGS



# Accepted values for --set; main() raises ValueError for anything else.
SETS = ['train', 'val', 'trainval', 'test']
# The dataset splits: 'train', 'val', 'trainval', 'test'.
YEARS = ['VOC2007', 'VOC2012', 'merged']

# Accepted values for --year: VOC2007, VOC2012, or 'merged' (both years are
# converted into the same output file).



def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data: x coordinates are divided by the annotated image width and
  y coordinates by the annotated image height.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict). If it has no
      'object' entry, the example is written with empty per-object lists.
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset  (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
    KeyError: if an object's class name is not present in label_map_dict.
  """
  # Image path: <dataset_directory>/<folder>/<image_subdirectory>/<filename>.
  # 'folder' comes from the annotation itself (presumably the VOC year
  # directory, e.g. VOC2007 -- confirm against the dataset layout).
  img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
  full_path = os.path.join(dataset_directory, img_path)

  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()

  # Open the image only to check its format; the raw encoded bytes are what
  # get stored in the record.
  image = PIL.Image.open(io.BytesIO(encoded_jpg))
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')

  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []

  # An annotation with no <object> elements presumably has no 'object' key in
  # the parsed dict (verify against dataset_util.recursive_parse_xml_to_dict).
  # Treat it as an image without boxes instead of failing with KeyError.
  for obj in data.get('object', []):
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
      continue

    difficult_obj.append(int(difficult))

    # Normalize pixel coordinates by the annotated image size.
    bndbox = obj['bndbox']
    xmin.append(float(bndbox['xmin']) / width)
    ymin.append(float(bndbox['ymin']) / height)
    xmax.append(float(bndbox['xmax']) / width)
    ymax.append(float(bndbox['ymax']) / height)

    classes_text.append(obj['name'].encode('utf8'))
    # Raises KeyError for a class name missing from the label map.
    classes.append(label_map_dict[obj['name']])
    truncated.append(int(obj['truncated']))
    poses.append(obj['pose'].encode('utf8'))

  # Pack the parsed image and per-object fields into a tf.Example.
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))

  return example


def main(_):
  """Convert the selected PASCAL VOC split(s) into a single TFRecord file.

  Steps:
  - Validate --set and --year, then load the label map.
  - For each selected year, read the example ids listed in
    <data_dir>/<year>/ImageSets/Main/aeroplane_<set>.txt.
  - Parse each id's annotation XML and convert it with dict_to_tf_example.
  - Write every serialized tf.Example to FLAGS.output_path.

  Raises:
    ValueError: if --set is not in SETS or --year is not in YEARS; also
      propagated from dict_to_tf_example for non-JPEG images.
  """
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if FLAGS.year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))

  data_dir = FLAGS.data_dir
  # 'merged' converts VOC2007 and VOC2012 into the same output file.
  years = ['VOC2007', 'VOC2012']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]

  # Load the label map before creating the output file, so a bad
  # --label_map_path fails without leaving an open writer behind.
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
  try:
    for year in years:
      logging.info('Reading from PASCAL %s dataset.', year)
      # NOTE(review): the per-class list aeroplane_<set>.txt is used as the
      # image-id list for the whole split. This presumably works because each
      # per-class file in ImageSets/Main lists every image of the split, and
      # read_examples_list keeps only the id column -- confirm.
      examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                                   'aeroplane_' + FLAGS.set + '.txt')
      annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
      examples_list = dataset_util.read_examples_list(examples_path)
      num_examples = len(examples_list)
      for idx, example in enumerate(examples_list):
        if idx % 100 == 0:
          logging.info('On image %d of %d', idx, num_examples)
        path = os.path.join(annotations_dir, example + '.xml')
        # Read raw bytes: lxml's fromstring() rejects a str that carries an
        # XML encoding declaration, but accepts the same content as bytes.
        with tf.gfile.GFile(path, 'rb') as fid:
          xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        # The annotation XML holds one image's labels (size, objects, boxes).
        data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
        tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                        FLAGS.ignore_difficult_instances)
        writer.write(tf_example.SerializeToString())
  finally:
    # Close the output file even if a conversion step raises part-way.
    writer.close()

if __name__ == '__main__':
  # Script entry point: hand main() to tf.app.run explicitly.
  tf.app.run(main=main)

转载自 blog.csdn.net/qq_32799915/article/details/77163209