Getting Started with TensorFlow: Reading and Writing the CIFAR-10 Dataset with TFRecord

Converting CIFAR-10 to TFRecord files:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/18 15:00
# @Author  : HJH
# @Site    : 
# @File    : convert_cifar10.py
# @Software: PyCharm

import tensorflow as tf
import os
import sys
import numpy as np
import pickle as p
from PIL import Image

_NUM_TRAIN_FILES = 5
LABELS_FILENAME = 'label.txt'
_CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck'
]


def _int64_feature(value):
    if not isinstance(value, (tuple, list)):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _image_to_tfexample(image_data, image_format, class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_data),
        'image/format': _bytes_feature(image_format),
        'image/class/label': _int64_feature(class_id)
    }))


def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
    with tf.gfile.Open(filename, 'rb') as f:
        # The CIFAR-10 batch files are pickled by Python 2; Python 3 needs encoding='bytes'.
        if sys.version_info < (3,):
            data = p.load(f)
        else:
            data = p.load(f, encoding='bytes')

    images = data[b'data']
    num_images = images.shape[0]
    images = images.reshape((num_images, 3, 32, 32))
    labels = data[b'labels']

    with tf.Graph().as_default():
        for j in range(num_images):
            sys.stdout.write('\r>> Reading file [%s] image %d/%d' % (
                filename, offset + j + 1, offset + num_images))
            sys.stdout.flush()

            image = np.squeeze(images[j]).transpose((1, 2, 0))
            image = Image.fromarray(image)
            image = image.resize((227, 227))
            # image.save('../images/image/' + str(j) + '.png')
            # The pixels are stored as raw RGB bytes (not PNG-encoded), which is
            # why the reading script below uses tf.decode_raw instead of decode_png.
            image = image.tobytes()
            label = labels[j]

            example = _image_to_tfexample(image, b'png', label)
            tfrecord_writer.write(example.SerializeToString())

    return offset + num_images


def _get_output_filename(dataset_dir, split_name):
    return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name)
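

def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
    # Not present in the original listing, but called in run() below: a minimal
    # sketch modeled on TF-slim's dataset_utils.write_label_file, writing one
    # "label:class_name" line per class to LABELS_FILENAME.
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))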


def run(dataset_dir):
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    training_filename = _get_output_filename(dataset_dir, 'train')
    testing_filename = _get_output_filename(dataset_dir, 'test')

    if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    # First, process the training data:
    with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
        offset = 0
        for i in range(_NUM_TRAIN_FILES):
            filename = os.path.join(dataset_dir,
                                    'data_batch_%d' % (i + 1))  # 1-indexed.
            offset = _add_to_tfrecord(filename, tfrecord_writer, offset)

    # Next, process the testing data:
    with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
        filename = os.path.join(dataset_dir,
                                'test_batch')
        _add_to_tfrecord(filename, tfrecord_writer)

    # Finally, write the labels file:
    labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
    write_label_file(labels_to_class_names, dataset_dir)

    print('\nFinished converting the Cifar10 dataset!')


if __name__ == '__main__':
    run('../images/cifar_10/')
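
Running the script produces cifar10_train.tfrecord, cifar10_test.tfrecord, and label.txt in the dataset directory, which must already contain the extracted CIFAR-10 batch files data_batch_1 ... data_batch_5 and test_batch. As a quick sanity check (not part of the original post; the path assumes the output directory used above), the number of serialized examples can be counted with tf.python_io.tf_record_iterator:

import tensorflow as tf

# Count the records in the generated training file; CIFAR-10 has 50000 training images.
count = sum(1 for _ in tf.python_io.tf_record_iterator('../images/cifar_10/cifar10_train.tfrecord'))
print('train examples:', count)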

Reading the TFRecord files:

Two different approaches are used to read the TFRecord files: tfr_reader reads with TFRecordReader and the queue-based input pipeline, while tfr_data reads with TFRecordDataset from the tf.data API. A minimal usage sketch for both readers follows the class.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/19 11:33
# @Author  : HJH
# @Site    : 
# @File    : read_cifar10.py
# @Software: PyCharm

import tensorflow as tf
from PIL import Image


class Read(object):

    def __init__(self, file_dir, batch_size):
        self.FILE_DIR = file_dir
        self.BATCH_SIZE = batch_size

    def _read_and_decode(self, serialized_example):
        keys_to_features = {
            'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
            'image/format': tf.FixedLenFeature([], tf.string, default_value='png'),
            'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64))
        }

        features = tf.parse_single_example(serialized_example, keys_to_features)
        image = tf.decode_raw(features['image/encoded'], tf.uint8)
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, [227, 227, 3])
        image = tf.image.per_image_standardization(image)
        label = features['image/class/label']
        label = tf.one_hot(label, 10, 1, 0)
        return image, label

    def tfr_reader(self, min_after_dequeue=1000):
        files = tf.train.match_filenames_once([self.FILE_DIR])
        filename_queue = tf.train.string_input_producer(files, shuffle=True)

        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        image, label = self._read_and_decode(serialized_example)

        capacity = min_after_dequeue + 3 * self.BATCH_SIZE
        image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=self.BATCH_SIZE, num_threads=1,
                                                          capacity=capacity, min_after_dequeue=min_after_dequeue)
        return image_batch, label_batch

    def tfr_data(self, shuffle=True):
        data = tf.data.TFRecordDataset(self.FILE_DIR)
        data = data.map(self._read_and_decode)
        if shuffle:
            # Shuffle individual examples before batching; shuffling after
            # batch() would only reorder whole batches.
            data = data.shuffle(buffer_size=10000)
        data = data.repeat().batch(self.BATCH_SIZE)

        iterator = data.make_one_shot_iterator()
        image, label = iterator.get_next()
        return image, label
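
Neither reader is exercised in the listing above. Below is a minimal usage sketch, assuming the training TFRecord produced earlier and a batch size of 32 (both are assumptions, not part of the original code). The tensors returned by tfr_data can be evaluated directly, while the queue-based tfr_reader additionally needs the local-variable initializer (match_filenames_once stores the file list in a local variable) and running queue runners:

if __name__ == '__main__':
    reader = Read('../images/cifar_10/cifar10_train.tfrecord', batch_size=32)

    image_ds, label_ds = reader.tfr_data()    # tf.data pipeline
    image_q, label_q = reader.tfr_reader()    # queue-based pipeline

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        images, labels = sess.run([image_ds, label_ds])
        print('tfr_data batch:', images.shape, labels.shape)    # (32, 227, 227, 3) (32, 10)

        images, labels = sess.run([image_q, label_q])
        print('tfr_reader batch:', images.shape, labels.shape)

        coord.request_stop()
        coord.join(threads)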

Reposted from blog.csdn.net/M_Z_G_Y/article/details/81387575