Possible causes of TensorFlow eval() blocking

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/l460133921/article/details/87558304

Problem description

When printing a Tensor with eval() inside a TensorFlow Session scope, the eval() call blocks and never returns.

Reproduction

#encoding=utf-8
import tensorflow as tf

def read_tfrecord():
    record_filename = "./data/dog_image.tfrecord"

    tf_record_filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(record_filename))
    tf_record_reader = tf.TFRecordReader()
    key, tf_record_serialized = tf_record_reader.read(tf_record_filename_queue)
    tf_record_feature = tf.parse_single_example(tf_record_serialized,
                                                features={
                                                    'label': tf.FixedLenFeature([], tf.string),
                                                    'image': tf.FixedLenFeature([], tf.string)
                                                })
    tf_record_image = tf.decode_raw(tf_record_feature['image'], tf.uint8)
    tf_record_label = tf.cast(tf_record_feature['label'], tf.string)
    return tf_record_image, tf_record_label

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess=sess, coord=coord)
    tf_record_image1, tf_record_label1 = read_tfrecord()

    print("before eval")
    print(tf_record_image1.eval()) #A
    print(tf_record_label1.eval()) #B
    print("end eval") #C
    coord.request_stop()
    coord.join(thread)

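Running this code prints "before eval", but nothing is printed at A or B, and "end eval" at C never appears either, which shows that execution is blocked inside eval().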

Solution

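Rewriting the read_tfrecord function from the code above as follows, with the filename queue created at module level outside the function (and therefore before the Session block), fixes the problem.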

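record_filename = "./data/dog_image.tfrecord"
# Create the filename queue outside read_tfrecord(), at module level,
# so it already exists when the Session block runs its initializers and starts the queue runners.
tf_record_filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(record_filename))
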
def read_tfrecord():
    tf_record_reader = tf.TFRecordReader()
    key, tf_record_serialized = tf_record_reader.read(tf_record_filename_queue)
    tf_record_feature = tf.parse_single_example(tf_record_serialized,
                                                features={
                                                    'label': tf.FixedLenFeature([], tf.string),
                                                    'image': tf.FixedLenFeature([], tf.string)
                                                })
    tf_record_image = tf.decode_raw(tf_record_feature['image'], tf.uint8)
    tf_record_label = tf.cast(tf_record_feature['label'], tf.string)
    return tf_record_image, tf_record_label

The exact cause is unknown to me; if anyone knows, please let me know.
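A likely explanation, consistent with the StackOverflow answer listed under References: tf.train.string_input_producer() adds a QueueRunner to the graph's tf.GraphKeys.QUEUE_RUNNERS collection, and tf.train.start_queue_runners() starts threads only for the runners already in that collection at the moment it is called. In the reproduction, read_tfrecord() runs after start_queue_runners(), so the filename queue's runner is never started and eval() waits forever. Likewise, match_filenames_once() creates a local variable only after local_variables_initializer() has already run. The pure-Python sketch below is only a toy analogy of that ordering problem; ToyGraph, make_input_queue, start_runners and read_or_none are hypothetical stand-ins, not TensorFlow APIs, and the timeout replaces the real, unbounded wait.

```python
import queue
import threading

# Toy analogy only, NOT TensorFlow's API. ToyGraph, make_input_queue,
# start_runners and read_or_none are hypothetical stand-ins that mimic two
# TF 1.x facts: input producers register a runner in a graph collection, and
# start_queue_runners() starts only the runners registered when it is called.


class ToyGraph:
    def __init__(self):
        self.runners = []  # plays the role of the QUEUE_RUNNERS collection

    def make_input_queue(self, filenames):
        """Like string_input_producer(): create a queue and register a runner
        that will fill it, but do not start the runner yet."""
        q = queue.Queue(maxsize=32)

        def runner():
            for name in filenames:
                q.put(name)

        self.runners.append(runner)
        return q

    def start_runners(self):
        """Like start_queue_runners(): start only the runners registered so far."""
        threads = [threading.Thread(target=r, daemon=True) for r in self.runners]
        for t in threads:
            t.start()
        return threads


def read_or_none(q, timeout=0.5):
    """Stand-in for eval(): the real call would block forever; this one times out."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


def buggy_order():
    g = ToyGraph()
    threads = g.start_runners()                     # nothing registered yet
    q = g.make_input_queue(["dog_image.tfrecord"])  # registered too late
    return read_or_none(q), len(threads)


def fixed_order():
    g = ToyGraph()
    q = g.make_input_queue(["dog_image.tfrecord"])  # registered first
    threads = g.start_runners()
    return read_or_none(q), len(threads)


print(buggy_order())  # (None, 0)
print(fixed_order())  # ('dog_image.tfrecord', 1)
```

In the buggy order the runner list is still empty when the runners are started, so nothing ever feeds the queue; creating the queue first, as the module-level layout in the solution does, lets its runner start.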

Another cause

You must call tf.train.start_queue_runners(sess) before you call train_data.eval() or train_labels.eval().

This is a(n unfortunate) consequence of how TensorFlow input pipelines are implemented: the tf.train.string_input_producer(), tf.train.shuffle_batch(), and tf.train.batch() functions internally create queues that buffer records between different stages in the input pipeline. The tf.train.start_queue_runners() call tells TensorFlow to start fetching records into these buffers; without calling it the buffers remain empty and eval() hangs indefinitely.
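The buffering described above can be illustrated with a small producer/consumer sketch in plain Python. It is a toy analogy, not TensorFlow code: a bounded queue.Queue stands in for a pipeline buffer, a producer thread stands in for a QueueRunner, and a threading.Event stands in for the stop flag behind tf.train.Coordinator's request_stop()/join(). The names producer, read and demo are hypothetical, and reads use a timeout so the "hang" shows up as None instead of a real infinite wait.

```python
import queue
import threading

# Toy analogy, NOT TensorFlow code: a bounded buffer between pipeline stages,
# a producer thread standing in for a QueueRunner, and a threading.Event
# standing in for tf.train.Coordinator's stop flag.


def producer(buf, records, stop):
    """Feed records into the buffer, cycling through them (loosely like
    string_input_producer with num_epochs=None, minus shuffling), until stopped."""
    i = 0
    while not stop.is_set():
        try:
            buf.put(records[i % len(records)], timeout=0.1)
            i += 1
        except queue.Full:
            continue  # buffer full: re-check the stop flag, then retry


def read(buf, timeout=0.5):
    """Stand-in for eval(): returns None instead of hanging forever."""
    try:
        return buf.get(timeout=timeout)
    except queue.Empty:
        return None


def demo():
    buf = queue.Queue(maxsize=4)
    stop = threading.Event()
    records = ["rec0", "rec1", "rec2"]

    before = read(buf)  # no producer running yet: the buffer stays empty

    t = threading.Thread(target=producer, args=(buf, records, stop), daemon=True)
    t.start()  # analogous to start_queue_runners()
    after = [read(buf) for _ in range(4)]

    stop.set()  # analogous to coord.request_stop()
    t.join()    # analogous to coord.join(threads)
    return before, after, t.is_alive()


print(demo())  # (None, ['rec0', 'rec1', 'rec2', 'rec0'], False)
```

The put() timeout inside the producer is a design choice of this sketch: it keeps a producer blocked on a full buffer from ignoring the stop request, so join() returns promptly.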

References

  1. https://stackoverflow.com/questions/38589255/tensorflow-eval-never-ends
