tensorflow(八):数据处理

版权声明:转载请注明出处~ https://blog.csdn.net/sinat_31425585/article/details/87909114

1、制作tfrecords

使用tf.train.Example来对数据和标签进行封装,然后采用tf.python_io.TFRecordWriter方法进行写操作。

import tensorflow as tf
import os
import PIL.Image as Image

cwd = './card_data_v1.0'
classes = {'four', 'nine', 'one', 'rectangle' }
writer = tf.python_io.TFRecordWriter('train.tfrecords')
for index, name in enumerate(classes):
    class_path = cwd + '/' name + '/'
    for img_name in os.listdir(class_path):
        img_path = class_path + img_name
        img = Image.open(img_path)
        img = img.resize((256, 256))
        img_raw = img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            "img_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                
            }))
        writer.write(example.SerializeToString())

writer.close()

2、数据载入

先使用tf.train.string_input_producer将tfrecords文件输出到一个输入管道队列,然后采用tf.TFRecorderReader方法进行读取,最后用tf.parse_single_example进行解析,这样就可以将数据和类别标签解析出来。

注意 tf.train.string_input_producer格式为:

tf.train.string_input_producer(file_name, shuffle=False, num_epochs=5)

第一个参数为列表类型的文件名队列,第二个参数为是否打乱顺序,第三个为迭代epoch次数

def read_record(record_name):
    filename_queue = tf.train.string_input_producer([record_name])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string),
    })

    label = features["label"]
    image = features["image_raw"]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.reshape(image, [256, 256, 3])
    label = tf.cast(label, tf.int32)
    return image, label

3、生成一个batch

使用tf.train.shuffle_batch或tf.train.batch方法,从数据中取出batch张图片和对应标签

image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                   batch_size=20, capacity=30, 
                                                   min_after_dequeue=10)

最后,测试程序为:

image, label = read_record("./card_data_v1.0/train.tfrecords")
print(image, label)
image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size=4, capacity=10, min_after_dequeue=5)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
thread = tf.train.start_queue_runners(sess=sess)
for i in range(10):
    images, labels = sess.run([image_batch, label_batch])
    print("batch shape = ", image.shape, "labels = ", labels)
    print("label = ", labels)

    for j in range(4):
        plt.subplot(1, 4, j+1)
        plt.axis("off")
        plt.imshow(images[j])

这里要注意一下:tf.train.start_queue_runners,没有调用的话,整个系统处于"停滞"状态,因为只有当调用了tf.train.start_queue_runners之后,文件名队列才会被加载到内存中,计算单元就可以拿到数据进行计算了。

总结一下:

第一步:使用tf.train.string_input_producer建立队列;

第二步:使用reader.read读取文件;

第三步:调用tf.train.start_queue_runners,将文件名队列加载到内存中;

最   后:通过sess.run()获取读取图片的结果。

参考资料:

[1] https://blog.csdn.net/xuan_zizizi/article/details/78431490

[2] https://blog.csdn.net/sinat_34474705/article/details/78966064

[3] 21个项目玩转深度学习

猜你喜欢

转载自blog.csdn.net/sinat_31425585/article/details/87909114