Creating Your Own TFRecord Dataset in TensorFlow

Reference: 《TensorFlow实战Google深度学习框架》

TFRecord Format Overview

Data in a TFRecord file is stored as tf.train.Example Protocol Buffer messages (a binary format). The definition is as follows:

message Example{
    Features features = 1;
};
message Features{
    map<string,Feature> feature = 1;
};
message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

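An Example effectively stores a dictionary that maps feature names to values. Each feature name is a string, and each value is one of three list types: a list of byte strings (BytesList), a list of floats (FloatList), or a list of integers (Int64List). For an image, for example, the pixel data can be saved as a single byte string and the corresponding label as an integer list.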
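As a rough illustration of that dictionary structure, the sketch below builds one Example by hand, serializes it, and parses it back. The four-byte fake_pixels value and the label 1 are made-up stand-ins for real image data, and the round trip relies only on the standard protobuf SerializeToString and FromString methods.

```python
import tensorflow as tf

# Hypothetical stand-in for real pixel data: just four raw bytes.
fake_pixels = b"\x00\x01\x02\x03"

# Feature names are strings; each value is wrapped in one of the three list types.
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[fake_pixels])),
}))

# Serialize to the binary form that is written to a TFRecord file.
serialized = example.SerializeToString()

# Parse the bytes back into an Example and look the features up by name.
restored = tf.train.Example.FromString(serialized)
restored_label = restored.features.feature["label"].int64_list.value[0]
restored_pixels = restored.features.feature["image_raw"].bytes_list.value[0]
```

Even a single label is stored as a one-element list, which is why the value is read back with value[0].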

Creating a TFRecord File

First, import the required libraries (the code was run in a Jupyter notebook):

import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt
%matplotlib inline 

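Data Preprocessing

I downloaded 10 images from the web (3 cats, 4 dogs, 3 horses) and stored them in the cat, dog and horse folders. Because the downloaded images differ in size and format, they are preprocessed first with the function below (only the function is shown here; the complete test code is at the end of the post):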

def preprocess(imageRawDir, imageDir):
    """
    images preprocess

    Arguments:
    imageRawDir -- directory of primary images.
    imageDir -- directory of processed images.

    Return: none.
    """
    imageNames = os.listdir(imageRawDir)
    label = imageDir.split("/")[-2] # directory format:"./data/cat/"
    for index, imageName in enumerate(imageNames):
        image = Image.open(os.path.join(imageRawDir,imageName))
        # convert to RGB so every image has 3 channels (readRecord reshapes to [256,256,3])
        image = image.convert("RGB").resize((256, 256))
        savePath = os.path.join(imageDir, label+"_"+str(index)+".jpg")
        image.save(savePath)

The preprocessed images are saved in a separate, specified folder.
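The readRecord function later in this post reshapes every record to [256, 256, 3], so each stored image must contain exactly 256 * 256 * 3 bytes. The check below is a minimal sketch of that requirement. The in-memory test images are hypothetical stand-ins for downloaded pictures, and converting to RGB before resizing is one way (an assumption here, not the only option) to get three channels regardless of the source image mode.

```python
from PIL import Image

# Hypothetical in-memory inputs standing in for downloaded pictures.
rawImages = [
    Image.new("RGBA", (300, 200), (255, 0, 0, 128)),  # has an alpha channel
    Image.new("L", (64, 64), 90),                      # grayscale, one channel
    Image.new("RGB", (512, 512), (0, 255, 0)),         # already RGB, wrong size
]

# Force three channels, then resize to the fixed 256x256 target.
processed = [im.convert("RGB").resize((256, 256)) for im in rawImages]

# Number of raw bytes each image would contribute to a record.
byteCounts = [len(im.tobytes()) for im in processed]
expectedBytes = 256 * 256 * 3
```

Without the RGB conversion, the grayscale image would yield 256 * 256 bytes and the RGBA image 256 * 256 * 4 bytes, neither of which fits the [256, 256, 3] shape.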

Writing to the TFRecord File

The following two helper functions are used when creating the TFRecord file. Without them, the code becomes long and hard to read.

def _int64_feature(value):
    """
    generate int64 feature.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    """
    generate byte feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

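The function below reads all the preprocessed images and writes them one by one to a TFRecord file. Because the dataset here is tiny, everything is written to a single TFRecord file; with a large dataset, the records can also be split across multiple files.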

def createRecord(imageDir):
    """
    create TFRecord data.

    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    # create a writer to write TFRecord file
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    classNames = ["cat", "dog", "horse"]

    for classIndex, className in enumerate(classNames):
        print("class name =  {}".format(className))
        currentClassDir = os.path.join(imageDir,className)
        print("current dir =  {}".format(currentClassDir))
        for index, imageName in enumerate(os.listdir(currentClassDir)):
            image = Image.open(os.path.join(currentClassDir,imageName))
            image_raw = image.tobytes() # convert image to raw pixel bytes
            print("{} {}".format(index, imageName))

            # write image data(pixel values and label) to Example Protocol Buffer
            example = tf.train.Example(features = tf.train.Features(feature = {
            "label": _int64_feature(classIndex),
            "image_raw": _bytes_feature(image_raw),
            }))

            # write an example to TFRecord file
            writer.write(example.SerializeToString())
    writer.close()
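For the multi-file case mentioned above, one possible approach is to distribute the images over several shard files in round-robin order. The sketch below covers only the bookkeeping. The helper names shardFileNames and assignToShards, the "-00000-of-00002" naming pattern, and the sample file names are all assumptions made for illustration; each shard would then be written with its own writer in the same way as in createRecord.

```python
def shardFileNames(baseName, numShards):
    """Return one output file name per shard (naming pattern is an assumption)."""
    return ["{}-{:05d}-of-{:05d}".format(baseName, i, numShards)
            for i in range(numShards)]


def assignToShards(items, numShards):
    """Distribute items over numShards lists in round-robin order."""
    shards = [[] for _ in range(numShards)]
    for index, item in enumerate(items):
        shards[index % numShards].append(item)
    return shards


# Hypothetical image names standing in for the preprocessed files.
imageNames = ["cat_0.jpg", "cat_1.jpg", "cat_2.jpg", "dog_0.jpg", "dog_1.jpg"]

names = shardFileNames("train.tfrecords", 2)
shards = assignToShards(imageNames, 2)

for name, shard in zip(names, shards):
    print("{} <- {}".format(name, shard))
```

Round-robin assignment keeps the shards close to equal in size. The resulting list of file names can be passed to tf.train.string_input_producer, which accepts a list of files.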

Reading the TFRecord File

def readRecord(recordName):
    """
    read TFRecord data (images).

    Arguments:
    recordName -- the TFRecord file to be read.
    return: data saved in recordName (image and label).
    """
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    })

    label = features["label"]
    image = features["image_raw"]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.reshape(image,[256,256,3])
    label = tf.cast(label, tf.int32)
    return image, label
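The decode_raw and reshape steps in readRecord amount to reinterpreting a flat byte string as a height x width x channels array. A NumPy analogue makes this easier to inspect outside a TensorFlow graph. This is only a sketch: the synthetic pixel values below are an assumption standing in for the bytes produced by image.tobytes().

```python
import numpy as np

height, width, channels = 256, 256, 3

# Hypothetical pixel data standing in for a real 256x256 RGB image.
pixels = np.arange(height * width * channels, dtype=np.uint32) % 256
original = pixels.astype(np.uint8).reshape(height, width, channels)

# What gets stored in the "image_raw" feature: a flat byte string.
image_raw = original.tobytes()

# Analogue of tf.decode_raw(image, tf.uint8): bytes -> flat uint8 vector.
decoded = np.frombuffer(image_raw, dtype=np.uint8)

# Analogue of tf.reshape(image, [256, 256, 3]): restore the image shape.
restored = decoded.reshape(height, width, channels)
```

The byte string carries no shape information, so the reader has to know the image size in advance. That is why every image is resized to the same 256 x 256 size before being written.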

Note that the values returned here are tensors. The actual data is only available after creating a session in TensorFlow and running them, as follows:

##test code
image, label =  readRecord("./data/train.tfrecords")
print("{} {}".format(image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)


init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
thread = tf.train.start_queue_runners(sess=sess)
for i in range(10):
    #print image_batch.shape, label.shape
    images, labels = sess.run([imageBatch, labelBatch])
    print("batch shape =  {} labels =  {}".format(images.shape, labels))
print("label =  {}".format(labels))
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.axis("off")
    plt.imshow(images[i])

Output:

batch shape =  (4, 256, 256, 3) labels =  [0 1 1 0]
batch shape =  (4, 256, 256, 3) labels =  [1 0 0 0]
batch shape =  (4, 256, 256, 3) labels =  [1 2 2 1]
batch shape =  (4, 256, 256, 3) labels =  [1 2 1 2]
batch shape =  (4, 256, 256, 3) labels =  [0 0 1 0]
batch shape =  (4, 256, 256, 3) labels =  [2 1 2 1]
batch shape =  (4, 256, 256, 3) labels =  [1 1 0 0]
batch shape =  (4, 256, 256, 3) labels =  [0 2 1 1]
batch shape =  (4, 256, 256, 3) labels =  [2 2 1 0]
batch shape =  (4, 256, 256, 3) labels =  [1 2 0 1]
label =  [1 2 0 1]

(Figure: batch images)
The images and labels in the last batch match one to one (0: cat; 1: dog; 2: horse), which shows that the data was read back from the TFRecord file successfully.
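The integer labels come from enumerate(classNames) in createRecord, so they can be mapped back to class names by indexing the same list. A minimal sketch, using the last batch of labels from the output above:

```python
# Same class order as in createRecord; the label is the index into this list.
classNames = ["cat", "dog", "horse"]

# Labels of the last batch, copied from the printed output.
labels = [1, 2, 0, 1]

# Translate each integer label back to its class name.
names = [classNames[i] for i in labels]
print(names)
```

These names can then be compared against the four plotted images to confirm the pairing by eye.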

Complete Sample Code

import numpy as np
import tensorflow as tf
from PIL import Image
import os
import matplotlib.pyplot as plt

currentDir = os.getcwd()
os.chdir(currentDir)
print(currentDir)

def preprocess(imageRawDir, imageDir):
    """
    images preprocess

    Arguments:
    imageRawDir -- directory of primary images.
    imageDir -- directory of processed images.

    Return: none.
    """
    imageNames = os.listdir(imageRawDir)
    label = imageDir.split("/")[-2] # directory format:"./data/cat/"
    for index, imageName in enumerate(imageNames):
        image = Image.open(os.path.join(imageRawDir,imageName))
        # convert to RGB so every image has 3 channels (readRecord reshapes to [256,256,3])
        image = image.convert("RGB").resize((256, 256))
        savePath = os.path.join(imageDir, label+"_"+str(index)+".jpg")
        image.save(savePath)

##test code
catRawDir = "./data_raw/cat/"
catDir = "./data/cat/"
preprocess(catRawDir, catDir)

dogRawDir = "./data_raw/dog/"
dogDir = "./data/dog/"
preprocess(dogRawDir, dogDir)

horseRawDir = "./data_raw/horse/"
horseDir = "./data/horse/"
preprocess(horseRawDir, horseDir)


def _int64_feature(value):
    """
    generate int64 feature.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """
    generate byte feature.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def createRecord(imageDir):
    """
    create TFRecord data.

    Arguments:
    imageDir -- image directory.
    Return: none.
    """
    writer = tf.python_io.TFRecordWriter(os.path.join(imageDir, "train.tfrecords"))
    classNames = ["cat", "dog", "horse"]

    for classIndex, className in enumerate(classNames):
        print("class name =  {}".format(className))
        currentClassDir = os.path.join(imageDir,className)
        print("current dir =  {}".format(currentClassDir))
        for index, imageName in enumerate(os.listdir(currentClassDir)):
            image = Image.open(os.path.join(currentClassDir,imageName))
            image_raw = image.tobytes() # convert image to raw pixel bytes
            print("{} {}".format(index, imageName))

            example = tf.train.Example(features = tf.train.Features(feature = {
            "label": _int64_feature(classIndex),
            "image_raw": _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
    writer.close()


##test code
createRecord(os.path.join(currentDir, "data/"))


def readRecord(recordName):
    """
    read TFRecord data (images).

    Arguments:
    recordName -- the TFRecord file to be read.
    return: data saved in recordName (image and label).
    """
    filenameQueue = tf.train.string_input_producer([recordName])
    reader = tf.TFRecordReader()
    _, serializedExample = reader.read(filenameQueue)
    features = tf.parse_single_example(serializedExample, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    })

    label = features["label"]
    image = features["image_raw"]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.reshape(image,[256,256,3])
    label = tf.cast(label, tf.int32)
    return image, label


##test code
image, label =  readRecord("./data/train.tfrecords")
print("{} {}".format(image, label))
imageBatch, labelBatch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)


##test code
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
thread = tf.train.start_queue_runners(sess=sess)
for i in range(10):
    #print image_batch.shape, label.shape
    images, labels = sess.run([imageBatch, labelBatch])
    print("batch shape =  {} labels =  {}".format(images.shape, labels))
print("label =  {}".format(labels))
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.axis("off")
    plt.imshow(images[i])

Reposted from blog.csdn.net/sinat_34474705/article/details/78966064