TensorFlow学习(十六):使用tf.data来创建输入流(下)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xierhacker/article/details/79698165

上面已经说过了怎么使用tf.data处理简单的数据输入,有了上面的基础之后,这部分使用tf.data来创建更加复杂更加贴近于现实的数据输入. 这里主要使用tfrecords来创建输入流。之后训练模型非常方便,要是想通过其他的方式进行输入操作的,可以参考官方文档。
这一节可以看做是TensorFlow学习(十五):使用tf.data来创建输入流(上)TensorFlow学习(十一):保存TFRecord文件 这两节的后续。
要是对于怎么生成tfrecords不熟悉的话,可以参考这两节来复习。

这里给出了一些常见的使用案例,代码存放在:LearningTensorFlow/11.TFRecord/

一.主要API

还是老样子,这里先把最主要的API列在这里,后面会用到这些API,先混个脸熟.

tf.data.TFRecordDataset 类

__init__(filenames,compression_type=None,buffer_size=None)
创建一个TFRecordDataset

参数:`

  • filenames: 一个 tf.string 类型的tensor里面包含一个或者多个TFRecord文件的文件名
  • compression_type: (可选) ,可以是"" (没有压缩), "ZLIB", 或者 "GZIP".
  • buffer_size: (Optional.) A tf.int64 scalar representing the number of bytes in the read buffer. 0 means no buffering.

map(map_func,num_parallel_calls=None)
作用:在这整个dataset里面使用map_func来映射,实际上我们用的时候,可以通过这个函数来装换为一般的dataset.也就是返回一个Dataset对象.

参数:

  • map_func: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to another nested structure of tensors.
  • num_parallel_calls: (Optional.) A tf.int32 scalar tf.Tensor, representing the number elements to process in parallel. If not specified, elements will be processed sequentially.

apply(transformation_func)
在dataset上面应用一个转换函数。

dataset = (dataset.map(lambda x: x ** 2)
           .apply(group_by_window(key_func, reduce_func, window_size))
           .map(lambda x: x ** 3))

参数:
transformation_func: 接受一个Dataset 作为参数并且返回另外一个Dataset 的函数

当然这里还有一些batch(),shuffle()等等函数,这里就不讲了,上面一节有,这里的用法和上面一节是一样的。后面的例子可以清楚的看到。

tf.parse_single_example 函数

tf.parse_single_example(serialized,features,name=None,example_names=None)
作用:解析读入的单个Example proto.

Args:
serialized: 单个的序列化的Example.
features: A dict mapping feature keys to FixedLenFeature or VarLenFeature values.
name: A name for this operation (optional).
example_names: (Optional) A scalar string Tensor, the associated name. See _parse_single_example_raw documentation for more details.
Returns:
A dict mapping feature keys to Tensor and SparseTensor values.

Raises:
ValueError: if any feature is invalid.

二.例子

Ⅰ.从分散的文件产生dataset

Ⅱ.从TFRecord产生dataset

这里的操作是TensorFlow学习(十一):保存TFRecord文件.csv文件转为tfrecord文件的读取操作.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt


#tfrecord 文件列表
file_list=["train.tfrecords"]

#创建dataset对象
dataset=tf.data.TFRecordDataset(filenames=file_list)

#定义解析和预处理函数
def _parse_data(example_proto):
    parsed_features=tf.parse_single_example(
        serialized=example_proto,
        features={
            "image_raw":tf.FixedLenFeature(shape=(),dtype=tf.string),
            "label":tf.FixedLenFeature(shape=(),dtype=tf.int64)
        }
    )

    # get single feature
    raw = parsed_features["image_raw"]
    label = parsed_features["label"]
    # decode raw
    image = tf.decode_raw(bytes=raw, out_type=tf.int64)
    image=tf.reshape(tensor=image,shape=[28,28])
    return image,label

#使用map处理得到新的dataset
dataset=dataset.map(map_func=_parse_data)
#使用batch_size为32生成mini-batch
#dataset = dataset.batch(32)

#创建迭代器
iterator=dataset.make_one_shot_iterator()

next_element=iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        image, label = sess.run(next_element)
        print(label)
        print(image.shape)
        print(label.shape)
        #plt.imshow(image)
        #plt.show()

猜你喜欢

转载自blog.csdn.net/xierhacker/article/details/79698165