使用tf.data建立数据通道

使用tf.data建立数据通道

动 机

在机器学习项目中构建输入管道总是漫长而痛苦的,并且比构建实际模型需要更多的时间。在本教程中,我们将学习如何使用TensorFlow的数据集模块tf.data为图像和文本构建有效的管道。

目 标

  • 学习如何使用tf.data并练习
  • 建立高效的加载图像和处理图像的通道
  • 建立高效的文本处理通道,包括如何建立词库

内容目录

一、tf.data概述

  • 使用Text Example介绍tf.data
  • 迭代和转换
  • 为什么使用可初始化的迭代器
  • 代码示例地址

二、建立图像数据通道

三、建立文本数据通道

  • 文件格式
  • 数据合并
  • 创建词汇表
  • 创建填充批次
  • 计算句子大小
  • 高级用法-提取字符

四、最佳实践

  • 打乱和重复
  • 使用多线程实现并行化
  • 预取数据
  • 操作顺序

一、tf.data概述

  官方资源

 使用Text Example介绍tf.data

新建file.txt文件,包含语句

I use Tensorflow
You use PyTorch
Both are great

 使用tf.dataAPI读取文件:

dataset = tf.data.TextLineDataset('file.txt')

迭代和转换

dataset是Tensorflow的图节点,其包含着读取文件的指令。如果我们想要读取文件,我们需要初始化图并在会话中评估这个节点。尽管这个听起来很复杂,实际上恰恰相反。现在甚至数据集对象也是图的一部分,因此你不需要担心如何将数输入模型。

我们需要增加一些额外的代码,来完成工作。首先,我们创建一个基于整个数据集的迭代器对象。

扫描二维码关注公众号,回复: 5870888 查看本文章
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

The one_shot_iterator method creates an iterator that will be able to iterate once over the dataset. In other words, once we reach the end of the dataset, it will stop yielding elements and raise an Exception.

现在,next_element 是图的节点。每一次执行,它将包含着迭代器的下一个元素。执行如下

with tf.Session() as sess:
    for i in range(3):
        print(sess.run(next_element))

既然已经了解tf.data API的基本原理,下面介绍一些先进的技巧。

例如

import tensorflow as tf
print(tf.__version__)
dataset = tf.data.TextLineDataset('file.txt')
#the value in dataet :‘I use Tensorflow’->['I', 'use', 'Tensorflow']
#'You use PyTorch'->['You', 'use', 'PyTorch']
#基于分隔符分割输入的每个元素
#map函数映射输入函数到整个数据集
dataset = dataset.map(lambda string: tf.string_split([string]).values)

#将数据集中连续的元素以batch_size为单位集合成批次
dataset = dataset.batch(2)
#预取数据,即它总是使得一个批次的数据准备被加载。
dataset = dataset.prefetch(1)
#创建基于整个数据集的迭代器
iterator = dataset.make_one_shot_iterator()
#使用get_next()方法取出元素,每次执行,next_element保存迭代器的下一个元素
next_element = iterator.get_next()

with tf.Session() as sess:
    print(sess.run(next_element))

为什么使用初始化迭代

通过初始化节点,相当于重新加载数据(make_one_shot仅执行一个epoch),我们可以选择从头开始训练。这对于我们执行多次epoch操作,极为有利。

dataset = tf.data.TextLineDataset('file.txt')
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

with tf.Session() as sess:
  #初始化迭代器
  sess.run(init_op)
  print(sess.run(next_element))
  print(sess.run(next_element))
  #移动迭代器到最初
  sess.run(init_op)
  print(sess.run(next_element))

如何建立图像数据通道

假设我们已经有了一个包含所有JPEG图像名称的列表和一个与之对应的标签列表。

通道建立步骤如下:

  1. 从文件名和标签的切片创建数据集
  2. 使用长度等于数据集大小的buffer size,打乱数据集。这确保了良好的改组。
  3. 从图像文件名中解析像素值。使用多线程提升预处理的速度
  4. (可选操作)图像数据扩增。使用多线程提升预处理的速度。
  5. 批量处理图片
  6. 预取一个批次以确保批处理可以随时使用

tf.data.Dataset的方法的输入为其内部的数据。

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.shuffle(len(filenames))
dataset = dataset.map(parse_function, num_parallel_calls=4)
dataset = dataset.map(train_preprocess, num_parallel_calls=4)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)

parse_function功能如下:

  1. 读取文件内容
  2. 使用JPEG图像格式解码
  3. 转化为0到1的浮点值
  4. 修改尺寸到(64, 64)
def parse_function(filename, label):
    image_string = tf.read_file(filename)
    
    #Don't use tf.image.decode_image, or the output shape will be undefined.
    image = tf.image.decode_jepg(image_string, channels)
    
    #This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image, tf.float32)

    image = tf.image.resize_images(image, [64, 64])
    return resized_image, label

函数 train_preprocess(optionally)可用于执行数据扩增。

  • 以1/2的概率水平翻转图像
  • 应用随机亮度和饱和度
def train_preprocess(image, label):
  image  tf.image.random_flip_left_right(image)
  
  image = tf.image.random_brightness(image, max_delta=32)
  image = tf.image.random_saturation(image, lowe=0.5, upper=1.5)
  
  #Make sure the image is still in [0, 1]
  image = tf.clip_by_value(image, 0.0, 1.0)
  return image, label

建立文本数据输入通道https://cs230-stanford.github.io/tensorflow-input-data.html#an-overview-of-tfdata

猜你喜欢

转载自blog.csdn.net/tianzhiya121/article/details/89206421