TensorFlow输入数据处理框架


        如图,大致为输入数据处理流程示意图。输入数据处理第一步为获取存储训练数据的文件列表,在该图中文件列表为{A,B,C}。通过tf.train.string_input_producer函数可以选择性将文件顺序打乱,并加入输入队列。tf.train.string_input_producer函数会生成并维护一个输入文件队列,不同线程中的文件读取函数可以共享这个文件队列。

        在读取样例程序后,需要对图像进行预处理。预处理的过程也会通过tf.train.shuffle_batch提供的机制并行的跑在多个线程中。输入数据处理流程的最后通过tf.train.shuffle_batch函数将处理好的单个输入样例整理成batch提供给神经网络输入层。

import tensorflow as tf

#创建文件列表
files = tf.train.match_filenames_once("Records/output.tfrecords")
#创建文件输入队列
filename_queue = tf.train.string_input_producer(files, shuffle=False) 
# 读取文件。
# 解析数据。假设image是图像数据,label是标签,height、width、channels给出了图片的维度
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)

# 解析读取的样例。
features = tf.parse_single_example(
    serialized_example,
    features={
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'channels': tf.FixedLenFeature([], tf.int64)
    })
image, label = features['image'], features['label']
height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
channels = tf.cast(features['channels'], tf.int32)

# 从原始图像中解析出像素矩阵,并根据像素尺寸还原图像
decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)
decoded_image.set_shape([height, width, channels])
#定义神经网络输入层图片的大小
image_size = 299
# preprocess_for_train函数是对图片进行预处理的函数
distorted_image = preprocess_for_train(decoded_image, image_size, image_size,
                                       None)

#将处理后的图像和标签通过tf.train.shuffle_batch整理成神经网络训练时需要的batch
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size

image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
                                                    batch_size=batch_size, 
                                                    capacity=capacity, 
                                                    min_after_dequeue=min_after_dequeue)
# 定义神经网络的结构及优化过程。image_batch可以作为输入提供给神经网络的输入层
#label_batch则提供了输入batch中样例的正确答案
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

#声明会话并运行神经网络优化过程
with tf.Session() as sess:
    #神经网络训练准备工作,这些工作包括变量初始化、线程启动
    sess.run(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # 神经网络训练过程
    for i in range(TRAINING_ROUNDS):
        sess.run(train_step)
        
    #停止所有线程
    coord.request_stop()
    coord.join()

其代码如下:

猜你喜欢

转载自blog.csdn.net/dz4543/article/details/79658105