根据自己的数据集制作 TFRecords 文件(图片的标注信息不是 XML 格式,需要自行读取)

convert.py:转换流程的主文件,需要在其中定义好:
- dataset_dir:数据集(图片文件夹)所在的路径;
- output_dir:输出 TFRecords 文件的路径;
- dataset_name:数据集的名称(output_name 为输出文件的基础名)。

import tensorflow as tf
import pascalvoc_to_tfrecords
FLAGS = tf.app.flags.FLAGS
_DEFINE_STRING = tf.app.flags.DEFINE_string

# Dataset selector; main() only recognises 'TsingHua_Tencent'.
_DEFINE_STRING('dataset_name', 'TsingHua_Tencent',
               'The name of the dataset to convert.')
# Path of the dataset.
_DEFINE_STRING('dataset_dir',
               r'/mnt/Data/weiyumei/deeplearning/dataset/data/train',
               'Directory where the original dataset is stored.')
# Basename of the generated TFRecord files.
_DEFINE_STRING('output_name', 'TsingHua_Tencent_train',
               'Basename used for TFRecords output files.')
# Directory the TFRecord files are written to.
_DEFINE_STRING('output_dir',
               r'/mnt/Data/weiyumei/deeplearning/dataset/data/Tfrecords',
               'Output directory where to store TFRecords files.')


def main(_):
    """Validate the command-line flags and dispatch to the dataset converter.

    Raises:
      ValueError: if --dataset_dir is empty or --dataset_name is unknown.
    """
    dataset_dir = FLAGS.dataset_dir
    if not dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')

    print('Dataset directory:', dataset_dir)
    print('Output directory:', FLAGS.output_dir)

    # Only one dataset is supported; reject anything else up front.
    if FLAGS.dataset_name != 'TsingHua_Tencent':
        raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)
    pascalvoc_to_tfrecords.run(dataset_dir, FLAGS.output_dir, FLAGS.output_name)

if __name__ == '__main__':
    # tf.app.run parses the flags defined in this file, then calls main(argv).
    tf.app.run(main=main)

然后进行标注数据的映射和格式转换,代码位于:
pascalvoc_to_tfrecords.py


"""
Each TFRecord file contains at most SAMPLES_PER_FILES (6105) records. Each
record within the TFRecord file is a serialized Example proto. The Example
proto contains the following fields:

    image/encoded: string containing JPEG encoded image in RGB colorspace
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/channels: integer, number of channels of the decoded JPEG (3 for RGB)
    image/shape: list of 3 integers, [height, width, channels]
    image/format: string, specifying the format, always 'JPEG'


    image/object/bbox/xmin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/xmax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/label: list of integer specifying the classification index.
    image/object/bbox/label_text: list of string descriptions.

Note that the length of xmin is identical to the length of xmax, ymin and ymax
for each example.
"""
import os
import sys
import random

import numpy as np
import tensorflow as tf

import Basic
import glob
import anno_func
from dataset_utils import int64_feature, float_feature, bytes_feature
from pascalvoc_common import TSINGHUA_TECENT

# TFRecord conversion parameters.
RANDOM_SEED = 4242          # seed for the optional shuffle in run()
SAMPLES_PER_FILES = 6105    # maximum number of examples written per shard

# Dataset information loaded by the Basic module.
DATADIR, anns, target_classes = Basic.DATADIR, Basic.anns, Basic.target_classes


# Lazily built (session, placeholder, decode op) triple shared by all calls of
# _jpeg_shape, so decoding many images does not keep adding ops to a graph or
# opening a new Session per image. The session lives for the whole process.
_JPEG_DECODER = None


def _jpeg_shape(image_data):
    """Decode JPEG bytes with TensorFlow and return the decoded array's shape."""
    global _JPEG_DECODER
    if _JPEG_DECODER is None:
        graph = tf.Graph()
        with graph.as_default():
            encoded = tf.placeholder(dtype=tf.string)
            decoded = tf.image.decode_jpeg(encoded)
        _JPEG_DECODER = (tf.Session(graph=graph), encoded, decoded)
    sess, encoded, decoded = _JPEG_DECODER
    return sess.run(decoded, feed_dict={encoded: image_data}).shape


def _process_image(filename):
    """Read one image file and collect the annotations of its objects.

    Args:
      filename: string, path to an image file e.g. '/path/to/3545.jpg'. The
        base name without extension ('3545') is used as the annotation id.

    Returns:
      (image_data, label_classes, label_text, bbox_list, shape):
        image_data: raw bytes of the file;
        label_classes: int class of each object (first element of its
          target_classes entry);
        label_text: UTF-8 encoded category name of each object;
        bbox_list: one [xmin, ymin, xmax, ymax] list per object, values taken
          as-is from the annotation's bbox dict;
        shape: shape of the decoded image.
    """
    with tf.gfile.GFile(filename, 'rb') as f:
        image_data = f.read()
    shape = _jpeg_shape(image_data)

    # '/path/to/3545.jpg' -> '3545'. Unlike splitting on the first '.', this
    # also works for relative paths such as '../../3545.jpg'.
    imgid = os.path.splitext(os.path.basename(filename))[0]
    # Map the image id to its annotation record.
    imgdata = anno_func.load_img(anns, DATADIR, imgid)
    _, _, img = anno_func.draw_all(anns, DATADIR, imgid, imgdata)

    label_classes = []
    label_text = []
    bbox_list = []
    for obj in img['objects']:
        category = obj['category']
        label_text.append(category.encode())
        bbox = obj['bbox']
        assert len(bbox) == 4
        bbox_list.append([bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']])
        label_classes.append(int(target_classes[category][0]))
    return image_data, label_classes, label_text, bbox_list, shape

# Convert the collected image information into a tf.train.Example.
def _convert_to_example(image_data, labels, labels_text, bboxes, shape):
    """Build an Example proto for one image.

    Args:
      image_data: bytes, JPEG encoding of the image;
      labels: list of integers, class id of each object;
      labels_text: list of bytes, encoded category name of each object;
      bboxes: list of boxes, one per object and parallel to `labels`; each box
          is a sequence of 4 numbers in [xmin, ymin, xmax, ymax] order (the
          order _process_image produces).
      shape: 3 integers, image shape (height, width, channels).
    Returns:
      Example proto
    Raises:
      ValueError: if a box does not have exactly 4 coordinates.
    """
    xmin, ymin, xmax, ymax = [], [], [], []
    for box in bboxes:
        if len(box) != 4:
            raise ValueError('Expected 4 box coordinates, got %r' % (box,))
        # Unpack in [xmin, ymin, xmax, ymax] order so each coordinate lands in
        # the feature of the same name (x and y must not be swapped).
        x0, y0, x1, y1 = box
        xmin.append(x0)
        ymin.append(y0)
        xmax.append(x1)
        ymax.append(y1)

    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(list(shape)),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
    return example

# Write one image (with its annotations) into the TFRecord file.
def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    """Serialize one image and its annotations and append it to a TFRecord.

    Args:
      dataset_dir: Dataset directory (not used by this function);
      name: Path of the image to add, handed to _process_image;
      tfrecord_writer: Open TFRecord writer that receives the serialized Example.
    """
    record = _process_image(name)
    image_data, labels, labels_text, bboxes, shape = record
    serialized = _convert_to_example(
        image_data, labels, labels_text, bboxes, shape).SerializeToString()
    tfrecord_writer.write(serialized)


def _get_output_filename(output_dir, name, idx):
    return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)


def run(dataset_dir, output_dir, name=' TsingHua_Tencent', shuffling=False):
    """Convert every .jpg image under Basic.IMG_PATH into sharded TFRecords.

    Output files are named '<output_dir>/<name>_NNN.tfrecord' and hold at most
    SAMPLES_PER_FILES examples each.

    Args:
      dataset_dir: The dataset directory; only passed through to
          _add_to_tfrecord. NOTE: images are discovered under Basic.IMG_PATH,
          not under this directory.
      output_dir: Output directory; created if it does not exist.
      name: Basename of the output files. NOTE: the default has a leading
          space, which ends up in the file names; convert.py passes an explicit
          name.
      shuffling: If True, shuffle the image list with RANDOM_SEED.
    """
    # The writer needs output_dir to exist; dataset_dir is only read from.
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    # glob returns files in arbitrary order; sort so the processing order and
    # the seeded shuffle below are reproducible.
    img_dir = Basic.IMG_PATH
    filenames = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
    if not filenames:
        print('Warning: no .jpg images found in %s' % img_dir)
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    total = len(filenames)
    # One TFRecord file per slice of SAMPLES_PER_FILES images.
    for fidx, start in enumerate(range(0, total, SAMPLES_PER_FILES)):
        tf_filename = _get_output_filename(output_dir, name, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            for i in range(start, min(start + SAMPLES_PER_FILES, total)):
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, total))
                sys.stdout.flush()
                _add_to_tfrecord(dataset_dir, filenames[i], tfrecord_writer)

    print('\nFinished converting the dataset!')

dataset_utils.py:将图片信息转换为 tf.train.Feature 的封装函数

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

def int64_feature(value):
    """Wrap an int, or a list/tuple of ints, in an int64 tf.train.Feature.

    A scalar is wrapped in a one-element list. A list or tuple (e.g. an image
    shape) is used as the values directly instead of being nested.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value):
    """Wrap a float, or a list/tuple of floats, in a float tf.train.Feature.

    A scalar is wrapped in a one-element list. A list or tuple is used as the
    values directly instead of being nested.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def bytes_feature(value):
    """Wrap bytes, or a list/tuple of bytes, in a bytes tf.train.Feature.

    A scalar is wrapped in a one-element list. Text strings are encoded as
    UTF-8 first, because BytesList only accepts bytes on Python 3.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    # type(u'') is `str` on Python 3 and `unicode` on Python 2.
    text_type = type(u'')
    value = [v.encode('utf-8') if isinstance(v, text_type) else v for v in value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

pascalvoc_common.py 中保存的类别信息:

# Category name -> (label id, label id); ids are assigned 1..6 in listed order.
TSINGHUA_TECENT = {
    category: (label, label)
    for label, category in enumerate(('i1', 'i10', 'i11', 'i12', 'i13', 'i14'), start=1)
}

Basic.py:获取数据集的基本信息

import os
import json
import Basic
# Dataset root directory.
DATADIR = "/mnt/Data/weiyumei/deeplearning/dataset/data"
print('DATADIR', DATADIR)

# Directory containing the training images.
IMG_PATH = os.path.join(DATADIR, 'train')
# Directory containing the mark images.
MARK_PATH = os.path.join(DATADIR, 'marks')
# Annotation file with the object locations.
filedir = os.path.join(DATADIR, 'annotations.json')

# Image ids, one per line, read from test/ids.txt.
# `with` closes both files once they have been read.
with open(os.path.join(DATADIR, 'test', 'ids.txt')) as ids_file:
    ids = ids_file.read().splitlines()
with open(filedir) as ann_file:
    anns = json.load(ann_file)

# Number the categories listed in anns['types'] starting from 1; each value is
# a (label, label) pair.
targets = anns['types']
target_classes = {name: (idx, idx) for idx, name in enumerate(targets, start=1)}
# Background entry, used when no object data can be read.
# NOTE(review): this value is a bare int while every other entry is a tuple, so
# target_classes['None'][0] raises TypeError — confirm what consumers expect.
target_classes['None'] = 0
# Numeric labels by category name (same dict object as target_classes).
labels = target_classes

猜你喜欢

转载自blog.csdn.net/weiyumeizi/article/details/82108329
今日推荐