Tensorflow关于Dataset的一般操作

Dataset封装了很好的关于数据集的一些基本操作,在这里做一下总结。该对象的路径是:tensorflow.data.Dataset(这是1.4版本之后的)很大程度上参考了这篇博客

tf.data.Dataset.from_tensor_slices

tf.data.Dataset.from_tensor_slices表示从张量中直接读取数据。以最外维度作为一个分割界限。比如:

data = tf.data.Dataset.from_tensor_slices(
    np.ones(20).reshape(4, 5))

那么,data中的数据总共有4个,每个都是5*1的行向量。相当于进行了4次的切片操作。

tf.data.Dataset.make_one_shot_iterator

生成一个迭代器,用于便利所有的数据。一般用法如下:

tf.data.Dataset.make_one_shot_iterator.get_next()

每次列举出下一个数据集。
实例:

import tensorflow as tf
import numpy as np

data = tf.data.Dataset.from_tensor_slices(
    np.array([1, 2, 3, 4, 5]))

element = data.make_one_shot_iterator().get_next()  # 建立迭代器,并进行迭代操作

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(element))
    except tf.errors.OutOfRangeError:
        print("Out range !")

以字典的方式处理数据

import tensorflow as tf
import numpy as np

a = np.array(['a', 'b', 'c', 'd', 'e'])
b = np.array([1, 2, 3, 4, 5])

# 分别切分数据,以字典的形式存储
data = tf.data.Dataset.from_tensor_slices(
    {
        "label1": a,
        "label2": b
    }
)

it=data.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(it))
    except tf.errors.OutOfRangeError:
        print("out of range")

输出结果

{'label2': 1, 'label1': b'a'}
{'label2': 2, 'label1': b'b'}
{'label2': 3, 'label1': b'c'}
{'label2': 4, 'label1': b'd'}
{'label2': 5, 'label1': b'e'}

常用的数据集操作

map函数

与python中的map作用类似,对输入的数据进行预处理操作。

import tensorflow as tf
import numpy as np

a = np.array([1, 2, 3, 4, 5])

data = tf.data.Dataset.from_tensor_slices(a)
# 注意在这里是返回的集合,原来的集合不变
data = data.map(lambda x: x ** 2)

it = data.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(it))
    except tf.errors.OutOfRangeError:
        print("out of range")
batch函数

batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:

dataset = dataset.batch(32)
shuffle函数

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)
repeat函数

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

dataset = dataset.repeat(5)  # 重复5次数据

注意,必须指明重复的次数,否则会无限期的重复下去。

一种常规的用法:
dataset.shuffle(1000).repeat(10).batch(32)

把数据进行1000个为单位的乱序,重复10次,生成批次为32的batch

tf.data.TextLineDataset

这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。一般操作方式:

tf.data.TextLineDataset(file_path).skip(n)

读取文件,同时跳过前n行。

猜你喜欢

转载自blog.csdn.net/qq_35976351/article/details/80752535