深度学习系列教程(六)tf.data API 使用方法介绍

转载自https://zhuanlan.zhihu.com/p/32649553。谢谢作者辛苦整理。若侵权,告知即删。

倾心之作!天学网AI学院名师团队“玩转TensorFlow与深度学习模型”系列文字教程,本周带来tf.data API 使用方法介绍!

该教程通过知识点讲解+答疑指导相结合的方式,让大家循序渐进的了解深度学习模型并通过实操演示掌握相关框架及TensorFlow工具使用。

大家在学习和实操过程中,有任何疑问都可以通过学院微信交流群进行提问,有导师和助教、大牛等为您解答哦。(入群方式在文末

第六篇的教程主要内容:TensorFlow 数据导入 (tf.data API 使用介绍)。

tf.data 简介

以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues。前者使用起来比较灵活,可以利用Python处理各种输入数据,劣势也比较明显,就是程序运行效率较低;后面一种方法的效率较高,但是使用起来较为复杂,灵活性较差。

Dataset作为新的API,比以上两种方法的速度都快,并且使用难度要远远低于使用Queuestf.data中包含了两个用于TensorFLow程序的接口:DatasetIterator

Dataset(数据集) API 在 TensorFlow 1.4版本中已经从tf.contrib.data迁移到了tf.data之中,增加了对于Python的生成器的支持,官方强烈建议使用Dataset API 为 TensorFlow模型创建输入管道,原因如下:

  • 与旧 API(feed_dict 或队列式管道)相比,Dataset API 可以提供更多功能。
  • Dataset API 的性能更高。
  • Dataset API 更简洁,更易于使用。

将来 TensorFlow 团队将会将开发中心放在Dataset API而不是旧的API上。

Dataset

Dataset表示一个元素的集合,可以看作函数式编程中的 lazy list, 元素是tensor tuple。创建Dataset的方式可以分为两种,分别是:

  • Source
  • Apply transformation

Source

这里 source 指的是从tf.Tensor对象创建Dataset,常见的方法又如下几种:

tf.data.Dataset.from_tensors((features, labels))
tf.data.Dataset.from_tensor_slices((features, labels))
tf.data.TextLineDataset(filenames)
tf.data.TFRecordDataset(filenames)

作用分别为:从一个tensor tuple创建一个单元素的dataset;从一个tensor tuple创建一个包含多个元素的dataset;读取一个文件名列表,将每个文件中的每一行作为一个元素,构成一个dataset;读取硬盘中的TFRecord格式文件,构造dataset。

Apply transformation

第二种方法就是通过转化已有的dataset来得到新的dataset,TensorFLow tf.data.Dataset支持很多中变换,在这里介绍常见的几种:

dataset.map(lambda x: tf.decode_jpeg(x))
dataset.repeat(NUM_EPOCHS)
dataset.batch(BATCH_SIZE)

以上三种方式分别表示了:使用map对dataset中的每个元素进行处理,这里的例子是对图片数据进行解码;将dataset重复一定数目的次数用于多个epoch的训练;将原来的dataset中的元素按照某个数量叠在一起,生成mini batch。

TensorFlow 1.4 版本中还允许用户通过Python的生成器构造dataset,如:

def generator():
  while True:
    yield ...

dataset = tf.data.Dataset.from_generator(generator, tf.int32)

将以上代码组合起来,我们可以得到一个常用的代码片段:

# 从一个文件名列表读取 TFRecord 构成 dataset
dataset = TFRecordDataset(["file1.tfrecord", "file2.tfrecord"])
# 处理 string,将 string 转化为 tf.Tensor 对象
dataset = dataset.map(lambda record: tf.parse_single_example(record))
# buffer 大小设置为 10000,打乱 dataset
dataset = dataset.shuffle(10000)
# dataset 将被用来训练 100 个 epoch
dataset = dataset.repeat(100)
# 设置 batch size 为 128
dataset = dataset.batch(128)

Iterator

定义好了数据集以后可以通过Iterator接口来访问数据集中的tensor tuple,iterator保持了数据在数据集中的位置,提供了访问数据集中数据的方法。

可以通过调用 dataset 的 make iterator 方法来构建 iterator。

API 支持以下四种 iterator,复杂程度递增:

  • one-shot
  • initializable
  • reinitializable
  • feedable

one-shot

one-shot iterator 谁最简单的一种 iterator,仅支持对整个数据集访问一遍,不需要显式的初始化。one-shot iterator 不支参数化。以下代码使用tf.data.Dataset.range生成数据集,作用与 python 中的 range 类似。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

initializable

Initializable iterator 要求在使用之前显式的通过调用iterator.initializer操作初始化,这使得在定义数据集时可以结合tf.placeholder传入参数,如:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value

reinitializable

reinitializable iterator 可以被不同的 dataset 对象初始化,比如对于训练集进行了shuffle的操作,对于验证集则没有处理,通常这种情况会使用两个具有相同结构的dataset对象,如:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.Iterator.from_structure(training_dataset.output_types,
                                   training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

feedable

feedable iterator 可以通过和tf.placeholder结合在一起,同通过feed_dict机制来选择在每次调用tf.Session.run的时候选择哪种Iterator。它提供了与 reinitilizable iterator 类似的功能,并且在切换数据集的时候不需要在开始的时候初始化iterator,还是上面的例子,通过tf.data.Iterator.from_string_handle来定义一个 feedable iterator,达到切换数据集的目的:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
代码示例

这里举一个读取、解码图片,并且将图片的大小进行调整的例子:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

猜你喜欢

转载自blog.csdn.net/u012911202/article/details/84821070