Tensorflow 流水线并行读取数据

前言

一直以来都是用 tensorflow 框架实现深度学习的算法,在网络训练时有一个重要的问题就是训练数据的读取。tensorflow 支持流水线并行读取数据,这种方式将数据的读取和网络训练并行,数据读取效率和将所有数据载入内存后进行存取相当,却又不会增加内存开销,是很值得推荐的一种方式。这篇笔记就是总结一下自己在实际应用中的并行数据读取,留个备份,随时学习。

主要参考了 Google HDRnet 代码:https://github.com/mgharbi/hdrnet,CycleGAN 代码:https://github.com/vanhuyz/CycleGAN-TensorFlow

数据读取

HDRnet工程里的 data_pipeline.py 文件提供了非常清晰的流水线读取数据示例,在官方代码的基础上,可以很轻松地针对自己的应用实现一套数据读取接口,假设我们的训练数据存储在目录 training_data/input 和 training_data/output,input 存储网络输入,output 存储网络输出,一组训练样本的名称假设相同,均为二进制文件 *.dat,以下面代码为示例展示如何实现流水线并行数据读取:

def data_generator(params, data_path):
    filelist = os.listdir(data_path)      # 获取训练目录下的文件名列表
    if params.shuffle:
            random.shuffle(filelist)      # 随机打乱训练数据

    input_files = [os.path.join(dirname, 'input', f) for f in filelist if f.endswith('.dat')]        # 生成输入数据文件名列表
    output_files = [os.path.join(dirname, 'output', f) for f in filelist if f.endswith('.dat')]     # 生成目标输出文件名列表

   # 基于给定的文件名列表,创建先入先出的文件名队列,输入可以是多个文件名列表,输出对应的对个文件名队列 input_queue, output_queue
= tf.train.slice_input_producer( [input_files, output_files], shuffle=params.shuffle, seed=params.seed, num_epochs=params.num_epochs) input_reader = tf.read_file(input_queue)       # 创建 reader,读取输入数据 output_reader = tf.read_file(output_queue)      # 创建 reader,读取目标输出
   # 根据文件类型的不同解析数据,如果文件是图像,可以使用 tf.image.decode_jpeg 等函数解析
if os.path.splitext(input_files[0])[-1] == '.jpg': input = tf.image.decode_jpeg(input_reader, channels=3) else: input = tf.decode_raw(input_reader, data_type=tf.uint16)  # 如果是二进制信息存储,则可以使用 tf.decode_raw 函数解析 input = tf.reshape(input, [params.height, params.width, params.channel]) # 将数据 reshape 为正确的形状,此处以图像 (height, width, channel) 为例 if os.path.splitext(output_files[0])[-1] == '.jpg': output = tf.image.decode_jpeg(output_reader, channels=3) else: output = tf.decode_raw(output_reader, data_type=tf.uint16) input = tf.reshape(input, [params.height, params.width, params.channel])
   # 上面读取了单个输入和对应的目标输出,网络训练时如需数据增广,可以在读取单个训练对之后,使用函数对数据进行处理,扩大训练集 input, output
= augment_data(input, output) samples = {} # 将增广后的一对训练数据组织为字典的形式,便于后面组织成 batch samples['input'] = input samples['output'] = output if param.shuffle: # 创建批样例训练数据 samples = tf.train.shuffle_batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity, min_after_dequeue=params.min_after_dequeue) else: samples = tf.train.batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity) return samples # 返回一个 batch 的训练数据

代码中具体函数的接口可以通过 tensorflow 的文档查清。以上,只是声明了多线程的文件读取操作,并不会真正的读取数据,为了在会话执行时顺利地获取输入数据,需要使用 tf.train.start_queue_runners 来启动执行入队列操作的所有线程,具体过程包括:文件名入队到文件名队列样例入队到样例队列。示例代码如下:

params.shuffle = true
params.seed    = 1234
params.height  = 224
params.width   = 224
params.channel = 3
training_path  = 'dir/to/training/data'
training_samples = data_generator(params, training_path)
batch_inputs = training_samples['input']
batch_outputs = training_sample['output']

# 网络计算图创建
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)

loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.train.start_queue_runners(sess = sess)
  sess.run(train_op)
  ...

上面的代码中,输入输出各只有一张图像,展示了如何实现流水线读取,以及如何使用读取出的数据。当输入或者输出包含多个文件时,例如,输入是图像和其语义分割图,可以在 data_generator 函数中,增加对语义分割图的读取,相对应的,多了 seg_files、seg_queue、seg_reader、seg_map 以及最后的 samples['seg_map'] = seg_map

同样,当输入数据是其它格式时,只需要根据对应的格式修改数据读取的代码接,例如 CycleGAN 中,训练数据存储为 tfrecord 格式,需要修改的其实就是对文件的读取部分。

我们都知道,tensorflow 在创建网络计算图时,通常需要为网络输入和目标输出先声明 placeolder,但是上面的第二段示例代码则是直接使用数据读取的输出构建网络计算图,是不是说采用这种方式就不能采用常见方法那样,先定义 placeholder,再在网络训练中使用 feed_dict 填充数据呢?答案是可以的,方法也和通常的做法没有太大区别,示例如下:

x = tf.placeholder(...)
y = tf.palceholder(...) conv_1
= Conv2D(y, ...) ...
loss = tf.reduce_mean(tf.squared_difference(net_y, y))

train_op = tf.minimize(loss, ...)

with tf.Session() as sess:

  sess.run(tf.global_variables_initializer())
  tf.train.start_queue_runners(sess=sess)

  samples = data_generator(params, training_path)
  sess.run(train_op, feed_dict={x: samples['input'], y: samples['output']})

和第一种方法的区别是 data_generator 是在会话 sess 中调用,而不是在构建网络计算图时

需要注意的是,上面的方式容错性比较差,主要是因为采用多线程方式读取数据,队列操作后台线程的生命周期无管理机制,线程出现异常会导致程序崩溃,比较常见的异常是文件名队列或者样例队列越界抛出的 tf.errors.OutOfRangeError。为了处理这种异常,HDRnetCycleGAN 工程代码中都使用 tf.train.Coordinator 创建了管理多线程声明周期的协调器,其工作原理是通过监控 tensorflow 所有后台线程,当有线程出现异常时,协调器的 should_stop 成员方法返回 True,循环结束,然后会话执行协调器的 request_stop 方法,请求所有线程安全退出。一套完整的示例代码如下:

params.shuffle = true
params.seed = 1234
params.height = 224
params.width = 224
params.channel = 3
training_path = 'dir/to/training/data'
training_samples = data_generator(params, training_path)
batch_inputs = training_samples['input']
batch_outputs = training_sample['output']
# 网络计算图创建
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:    
    while
not coord.should_stop():
      sess.run(train_op)
  except KeyboardInterrupt:  # 响应 Ctrl+C 停止训练
    coord.request_stop()
  except Exception as e:  # 后台线程出现异常
    coord.request_stop(e)
  finally: # 这一步总会执行
    save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) # 保存 checkpoint
    coord.request_stop()
    coord.join(threads)

总结

以上,介绍而 tensorflow 中如何使用多线程并行读取数据,如何在训练中使用读取的数据,以及如何在对多线程进行监视,增加网络训练的鲁棒性。分享给大家,也为自己学习。

猜你喜欢

转载自www.cnblogs.com/beshining/p/10162008.html