tensorflow的数据读取一共有三种方式
- 供给数据(Feeding):在Tensorflow程序运行的每一步,让python代码来供给数据
- 从文件读取数据:在tensorflow图的起始,让一个输入管线从文件中读取数据
- 预加载数据:在tensorflow图中定义常量或变量来保存所有数据(仅仅适用于数据量比较小的情况)
供给数据
tensorflow的数据供给机制允许你在tensorflow运算图中将数据注入到任意张量中,因此,python运算可以把数据直接设置到tensorflow图中。然而却需要设置placeholder节点,通过run()函数输入feed_dict参数,可以启动运算过程。placeholder节点被声明的时候是未初始化的,也不包含数据,如果没有为它供给数据,则tensorflow运算的时候会产生错误。
在训练mnist手写字体识别时就使用到了feed_dict输入数据,部分代码如下。
完整代码见:https://github.com/skloisMary/LeNet-5.git。
def train(mnist):
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
y_ = LeNet.LeNet(x)
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_, labels=y)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
train_step = tf.train.AdamOptimizer(RATE).minimize(cross_entropy_mean)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as session:
session.run(tf.global_variables_initializer())
x_train, y_train = mnist.train.images, mnist.train.labels
for i in range(EPOCH):
x_train, y_train = shuffle(x_train, y_train)
print('EPOCH:', i)
for offset in range(0, len(x_train), BATCH_SIZE):
batch_x, batch_y = x_train[offset: offset+BATCH_SIZE], y_train[offset:offset+BATCH_SIZE]
session.run(train_step, feed_dict={x:batch_x, y:batch_y})
validation_accuracy = session.run(accuracy,
feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
print('valiation accuracy:', validation_accuracy)
# test
test_accuracy = session.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print('test accuracy:', test_accuracy) #
从文件读取数据
上述供给数据机制只适用于小数据,遇到大量数据的时候,效率低下,所以就需要从文件中读取数据,虽然从文件中直接读取数据效率高,但相应地也比feed方式复杂,下面我们详细介绍。
首先,使用tf.train.string_input_producer()函数产生一个先入先出的队列queueu, 如上图所示,此操作是将文件名堆入队列中。函数格式为tf.train.string_input_producer(string_tensor,num_epochs=None,shuffle=True),num_epochs和shuffer两个可配置参数设置最大的训练迭代次数和文件名乱序,shuffle默认为True,会对文件名进行乱序处理。
filename = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
# 创建文件队列,不限制读取的数量,所以没有设置num_epochs
filename_queue = tf.train.string_input_producer(filename)
其次,创建文件阅读器reader从队列中取文件名并读取数据, 不同reader对应不同的文件结构。我们以CIFAR-10二进制数据集为例,使用tf.fixedLengthRecordReader函数从二进制文件中读取固定长度数据。接下来,使用reader的read方法从上述创建的文件队列filename_queue中读取数据,并用tf.decode_raw()函数将读取的value值转换成一个uint8的张量,然后就可以通过切片和转换得到需要的格式。即最后得到上述动图的Example Queue。
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
# decode_raw操作将一个字符串转换成一个uint8的张量
record_bytes = tf.decode_raw(value, tf.uint8)
# tf.strides_slice(input, begin, end, strides=None)截取[begin, end)之间的数据
result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],[label_bytes+image_byte]), [result.depth, result.height, result.width])
# convert from [depth, height, width] to [height, width, depth]
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
最后,进行批处理,从Example Queue中批量取出样本,使用tf.train.shuffle_batch来实现,返回一个batch_size大小的样本和样本标签。
min_fraction_of_examples_in_queue = 0.4
min_queue_examples =int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue)
image_batch, label_batch = tf.train.shuffle_batch([float_image, read_input.label], batch_size=batch_size, capacity= min_queue_examples + 3 * batch_size, min_after_dequeue=min_queue_examples)
在训练步骤运行之前,需要调用tf.train.start_queue_runner()函数启动输入管道的线程,填充样本到样本队列中,以便出队操作可以从队列中拿到样本。和tf.train.Coordinator()配合使用,当有错误时,它会完全关闭掉开启的threads。
with tf.Session() as session:
session.run(tf.global_variables_initializer())
# 创建一个线程协调器,用来管理session中启动的所有线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=session, coord=coord)
for index in range(EPOCHES):
_, loss_value, accuracy_value, summary = session.run([t_optimizer, t_loss, t_accuracy, merged])
if index % 1000 == 0:
print('index:', index, ' loss_value:', loss_value, ' accuracy_value:', accuracy_value)
# 终止所有线程的命令
coord.request_stop()
# 把threads加入主线程,等到threads结束
coord.join(threads)
Coordinator类用来管理Session中的多个线程,可以用来同时停止多个工作线程并且向等待所有工作进程终止的线程报告异常, 此线程捕获到异常之后就会终止所有的线程。
cifar-10的完整程序在:https://github.com/skloisMary/CIFAR-10.git