tensorflow high level API---import data

一、基本机制

其实就是tf.data接口可以更好的处理大规模的数据和各种数据类型,还有处理复杂的转换。

(1)tf.data.Datasets代表了一个元素的序列,着每一个元素包含了一个或者多个张量实体。有两种创建数据集的方法:第一种(创造一个源)是通过Dataset.from_tensor_slices来构建一个数据集从一个or多个张量实体;第二种是应用一种转换,例如Dataset.batch()构建一个数据集从一个或者多个tf.data.Dataset实体。

(2)td.data.Iterator:提供了主要的一种从数据集上取得元素的方法。Iterator.get_next()可以获取数据集的下一个元素。

第一你需要定义一个源,例如,你可以在内存的张量上构建数据集,使用tf.data.Dataset.from_tensors()或者tf.data.Dataset.from_tensor_slices().当然如果在硬盘上你有以TFRecord形式存储的数据,你可以建立tf.data.TFRecordDataset。

一旦你有了dataset实体,你可以通过tf.data.Dataset实体的一些方法来转化成一个新的数据集。你可以通过Dataset.map()将每个元素进行转化,也可以通过Dataset.batch()将多个元素进行转换。

最常见的消耗数据的方法就是在数据集上建立一个迭代器,可以通过Dataset.make_one_shot_iterator()来每次获取数据集的一个元素。tf.data.Iterator提供了两个操作,第一个是Iterator.initializer确保你初始化或者再次初始化你的迭代状态,第二个是Iterator.get_next()获取下一个数据。

1.1 数据结构

一个数据集拥有很多个结构相同的元素,每个元素包含一个或者多个张量,称之为组件compoents.每一个组件有个tf.Dtype的属性可以表示元素的属性,tf.TensorShape表明了每一个静态元素的形状。Dataset.output_types和Dataset.output_shapes表明了每个数据集元素的组件的类型和结构。

案例:

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset1=tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))
print(dataset1.output_shapes)
print(dataset1.output_types)

dataset2=tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([4]),
     tf.random_uniform([4,100],maxval=100,dtype=tf.int32))
)
print(dataset2.output_types)
print(dataset2.output_shapes)

dataset3=tf.data.Dataset.zip((dataset1,dataset2))
print(dataset3.output_types)
print(dataset3.output_shapes)

结果:

(10,)
<dtype: 'float32'>
(tf.float32, tf.int32)
(TensorShape([]), TensorShape([Dimension(100)]))
(tf.float32, (tf.float32, tf.int32))
(TensorShape([Dimension(10)]), (TensorShape([]), TensorShape([Dimension(100)])))

每个元素的组件起个名字非常的方便,代码:

dataset2=tf.data.Dataset.from_tensor_slices(
    {'a':tf.random_uniform([4]),
     'b':tf.random_uniform([4,100],maxval=100,dtype=tf.int32)}
)
print(dataset2.output_types)
print(dataset2.output_shapes)

结果:

(10,)
<dtype: 'float32'>
{'a': tf.float32, 'b': tf.int32}
{'a': TensorShape([]), 'b': TensorShape([Dimension(100)])}

代码:你可以通过下面的方法对数据集的数据进行转变。

dataset1 = dataset1.map(lambda x: ...)

dataset2 = dataset2.flat_map(lambda x, y: ...)

# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)

1.2 创造迭代器

迭代器的类型:

one-shot、initializable、reinitializable、feedable

(1)one-shot最简单,不需要特殊的初始化操作,仅仅在数据集上迭代一次。这个支持处理几乎所有的现存的基于队列的输入流,但是不支持参数。Dataset.range()案例。只有这个可以被estimator使用。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset=tf.data.Dataset.range(100)
iterator=dataset.make_one_shot_iterator()
next_element=iterator.get_next()
with tf.Session() as sess:
    for i in range(100):
        value=sess.run(next_element)
        assert i==value

(2)initializer需要你运行一个iterator.initializer操作。当然你可以对你的数据集进行参数化,还可以使用多个tf.placeholder()可以在你初始化的时候传入。

max_value=tf.placeholder(tf.int64,shape=[])
dataset=tf.data.Dataset.range(max_value)
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer,feed_dict={max_value:10})
    for i in range(10):
        value=sess.run(next_element)
        assert i==value
    sess.run(iterator.initializer,feed_dict={max_value:100})
    for i in range(100):
        value=sess.run(next_element)
        assert i==value

(3)reinitializable:是一个可以从多个不同的数据集实体初始化的迭代器。例如你可能有一个训练数据的输入流,你会添加上一些扰乱从而来提高模型的泛化能力。然后你可能还有个交叉验证集的输入来在未修改的数据上进行评估结果。这些输入线使用的是不同的数据集,但是每一个组件是存在相同的类型和可兼容的shape的。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

training_dataset=tf.data.Dataset.range(100).map(
    lambda x:x+tf.random_uniform([],-10,10,tf.int64)
)
validation_dataset=tf.data.Dataset.range(50)

#这里的reinitializable迭代器可以从训练的数据集,当然也可以从交叉验证的数据集上获取类型和shape,兼容性
iterator=tf.data.Iterator.from_structure(training_dataset.output_types,
                                         training_dataset.output_shapes)

next_element=iterator.get_next()

training_init_op=iterator.make_initializer(training_dataset)
validation_init_op=iterator.make_initializer(validation_dataset)

#运行20个epochs,贯穿每个数据集,然后跟随着交叉验证集。
with tf.Session() as sess:
    for _ in range(20):
        sess.run(training_init_op)
        for _ in range(100):
            sess.run(next_element)
        
        #z在交叉验证集上初始化一个迭代器
        sess.run(validation_init_op)
        for _ in range(50):
            sess.run(next_element)



(4)feedable:这个迭代器可以在tf.Session.run的阶段内通过tf.placeholder来选择哪一个迭代器能被使用。当然也是使用的feed_dict机制,这个迭代器跟reinitializable的迭代器有的相同的功能,但是他不需要你每次在数据集的开头都需要初始化你的迭代器。使用tf.data.Iterator.from_string_handle来定义一个feedable迭代器。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

training_dataset=tf.data.Dataset.range(100).map(
    lambda x:x+tf.random_uniform([],-10,10,tf.int64)
).repeat()
validation_dataset=tf.data.Dataset.range(50)

#一个feedable迭代器基于handle placeholder来定义。当然还是可使用训练集或者交叉验证集的
#类型和大小,因为这两个数据集的结构一样。
handle=tf.placeholder(tf.string,shape=[])
iterator=tf.data.Iterator.from_string_handle(
    handle,training_dataset.output_types,training_dataset.output_shapes
)
next_element=iterator.get_next()

#你可以通过feedable迭代器来使用不同的的迭代器
training_iterator=training_dataset.make_one_shot_iterator()
validation_iterator=validation_dataset.make_initializable_iterator()

#Iterator.string_handle()方法返回的是一个张量,可以被evaluate,也可以被传入之前的handle占位符
with tf.Session() as sess:
    training_handle=sess.run(training_iterator.string_handle())
    validation_handle=sess.run(validation_iterator.string_handle())

    #在这两个数据集之间一直交换
    for i in range(2):
        #使用训练集跑200步,训练集是无限的,我们从while上次结束的地方继续。
        for _ in range(200):
            sess.run(next_element,feed_dict={handle:training_handle})

        #在交叉验证集上跑跑一次
        sess.run(validation_iterator.initializer)
        for _ in range(50):
            sess.run(next_element,feed_dict={handle:validation_handle})


1.3 从一次迭代中消耗值

tf.get_next()到达数据集的末尾的时候,会引发一个tf.errors.OutOfRangeError。如果你还要使用必须初始化。

代码:

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset=tf.data.Dataset.range(5)
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()

result=tf.add(next_element,next_element)

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run(result))
    print(sess.run(result))
    print(sess.run(result))
    print(sess.run(result))
    print(sess.run(result))
    try:
        sess.run(result)
    except tf.errors.OutOfRangeError:
        print("end of dataset")

结果:

0
2
4
6
8
end of dataset

很常见的方式是将我们的训练循环放在try-except的模块里面:

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset=tf.data.Dataset.range(5)
iterator=dataset.make_initializable_iterator()
next_element=iterator.get_next()

result=tf.add(next_element,next_element)

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            print(sess.run(result))
        except tf.errors.OutOfRangeError:
            break

如果我们的数据集存在嵌套的结构,Iterator.get_next()将得到一个或者多个张量的实体在一个相同的嵌套的结构里。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset1=tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))
dataset2=tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),
                                             tf.random_uniform([4,100])))
dataset3=tf.data.Dataset.zip((dataset1,dataset2))
iterator=dataset3.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    next1,(next2,next3)=iterator.get_next()
    print(next1,(next2,next3))

结果:

Tensor("IteratorGetNext:0", shape=(10,), dtype=float32) (<tf.Tensor 'IteratorGetNext:1' shape=() dtype=float32>, <tf.Tensor 'IteratorGetNext:2' shape=(100,) dtype=float32>)

1.4 保存迭代状态

从迭代器中利用tf.contrib.data.make_saveable_from_iterator函数创造一个SaveableObject,这个可以用来保存当前的迭代器的状态。一个可保存的实体可以添加到tf.train.Saver的变量表中,或者是tf.GraphKeys.SAVEABLE_OBJECTS集合来存储。

from __future__ import absolute_import,division,print_function
import tensorflow as tf

dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

result = tf.add(next_element, next_element)
#从迭代器上创造一个可保存的实体

saveable = tf.contrib.data.make_saveable_from_iterator(iterator)

#保存迭代器的状态,通过将其存储在一个可保存的实体集合中
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
should_checkpoint=True
path_to_checkpoint='./cp.ckpt'
with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        sess.run(result)
        if should_checkpoint:
            saver.save(sess,save_path=path_to_checkpoint)
    except tf.errors.OutOfRangeError:
        print("end of dataset")

# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)

二、读输入数据

2.1 消耗numpy数组

如果你的数据存放在内存中,最简单的方式去创造一个数据集就是将其转化成tf.Tensor并且使用Dataset.from_tensor_slices(),代码如下:

with np.load('/var/data/training_data.npy') as data:
    features=data['features']
    labels=data['labels']
    
assert features.shape[0]==labels.shape[0]
dataset=tf.data.Dataset.from_tensor_slices((features,labels))

但是上面的代码将你的feature和labels数组转成TensorFlow计算图里面的tf.constant()操作。虽然数据集很小但是还是很浪费内存。所以做下面的改变,将你的数据集转变成tf.placeholder(),然后等你初始化你的数据集上的迭代器的时候,才将numpy的数组传入。

with np.load('/var/data/training_data.npy') as data:
    features=data['features']
    labels=data['labels']
    
assert features.shape[0]==labels.shape[0]

features_placeholder=tf.placeholder(features.dtype,features.shape)
labels_placeholder=tf.placeholder(labels.dtype,labels.shape)

dataset=tf.data.Dataset.from_tensor_slices((features,labels))

iterator=dataset.make_initializable_iterator()

with tf.Sesssion() as sess:
    sess.run(iterator.initializer,feed_dict={features_placeholder:features,
                                             labels_placeholder:labels})

2.2 消耗TFRecord数据

tf.data API支持各种形式的文件类型,当你要处理大规模的数据的时候,不适合存放在内存。例如,TFRecord文件形式是一个简单的以记录为导向的二进制文本。tf.data.TFRecordDataset类确保了你可以在多个或者单个TFRecord文件上进行数据的流动。

代码:

#创建一个数据集,从下面的两个文件中读取数据
filenames=["/var/data/file1.tfrecord","/var/data/file2.tfrecord"]
#filenames这个参数可以是一个字符串或者是字符串的列表,或者是一个 字符串的张量。
dataset=tf.data.TFRecordDataset(filenames)
#filenames这个参数可以是一个字符串或者是字符串的列表,或者是一个 字符串的张量。
#所以当你存在一个为了训练一个为了交叉验证的数据集时,你可以使用tf.placeholder(tf.string)
#来代表你的文件名。然后在初始化迭代器
filenames=tf.placeholder(tf.string,shape=[None])
dataset=tf.data.TFRecordDataset(filenames)

dataset=dataset.map(...)#将记录换成张量的操作
dataset=dataset.repeat()#无穷尽的复制你的输入数据
dataset=dataset.batch(32)
iterator=dataset.make_initializable_iterator()

#你可以为当前的执行阶段传入合适的文件名,要么是为了训练,要么是为了交叉验证集
training_filenames=['/var/data/file1.tfrecord','/var/data/file2.tfrecord']

validation_filenames=['/var/data/validaton1.tfrecord']
with tf.Session() as sess:
    sess.run(iterator.initializer,feed_dict={filenames:training_filenames})
    sess.run(iterator.initializer,feed_dict={filenames:validation_filenames})

2.3 消耗text数据

很多的数据集分布在多个text文件。tf.data.TextLineDataset提供了一个很简单的方式从多个文本文件中读取行数据。他跟TFRecordDataset一样可以接受文件名作为tf.Tensor,所以你可以将其作为参数传给tf.placeholder(tf.string)

代码:

filename=["/var/data/file1.txt","/var/data/file2.txt"]
dataset=tf.data.TextLineDataset(filename)

默认来说TextLineDataset取得每个文件的每一行。但是这样是不行的,比如文件开头存在文件头,评论等等,是需要去除的。使用Dataset.skip()和Dataset.filter()来进行这一功能的实现。为了对于每个文件分开实现这样的功能,我们使用Dataset.flat_map()来为每个文件创建一个数据集。

代码:


filenames=["/var/data/file1.txt","/var/data/file2.txt"]
dataset=tf.data.Dataset.from_tensor_slices(filenames)

#使用Dataset.flat_map()来将每个文件转成一个单独的嵌入式数据集,然后将他们的数据线性的连接到一个flat数据集。
#跳过第一行
#过滤掉#开头的行
#使用substr来获取每一行的第一个字符(输入,pos,len),然后与#对比
dataset=dataset.flat_map(
    tf.data.TextLineDataset(filenames).skip(1).filter(
        lambda line:tf.not_equal(tf.substr(line,0,1),'#')
    )
)

2.4 消耗csv数据

csv文件类型是最受欢迎的存储表格形式的数据。tf.contrib.data.CsvDataset类提供了从一个或者多个csv文件中读取数据的方式。跟之前一样可以使用tf.placeholder(tf.string).

代码:

# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8   # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

如果一些列是空的,你可以提供默认值,不用提供类型。

代码:

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values
record_defaults = [[0.0]] * 8
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

默认,CsvDataset读取文件的每一行每一列,但是这样并不好,假如你的数据首行不要,一些列也不需要。这些不要的可以使用header和select_cols来去除。

代码:只要文件的第二列和第四列

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [[0.0]] * 2  # Only provide defaults for the selected columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2,4])

三、使用Dataset.map()预处理数据

Dataset.map(f)通过函数f改变数据集中的每一个元素f。

3.1 使用tf.Example协议内存信息

TFRecords文件包含了tf.train.Example 协议内存块(协议内存块包含了特征 Features)。我们可以写一段代码获取你的数据,将数据填入到Example协议内存块,将协议内存块序列化为一个字符串,并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

def _parse_function(example_proto):
    features={"image":tf.FixedLenFeature((),tf.string,default_value=""),
              "label":tf.FixedLenFeature((),tf.int32,default_value=0)}
    parsed_features=tf.parse_single_example(example_proto,features)
    return parsed_features["image"],parsed_features["label"]

filename=['/var/data/file1.tfrecord','/var/data/file2.tfrecord']
dataset=tf.data.TFRecordDataset(filename)
dataset=dataset.map(_parse_function)

3.2 解码图像数据并重新规划大小

def _parse_function(filename,label):
    image_string=tf.read_file(filename)
    image_decoded=tf.image.decode_jpeg(image_string)
    image_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label

filename=tf.constant(["/var/data/image1.jpg","var/data/image2.jpg"])
labels=tf.constant([0,37])
dataset=tf.data.Dataset.from_tensor_slices((filename,labels))
dataset=dataset.map(_parse_function())

3.3 通过tf.py_func()随意的使用python逻辑

当TensorFlow对数据的操作已经无法满足您的需求的时候,Dataset.map()里面的tf.py_func()操作可以帮您

import cv2
#使用opencv读取图像的方式,而不是使用标准的TensorFlow的tf.read_file()
def _read_py_function(filename,label):
    image_decoded=cv2.imread(filename.decode(),cv2.IMREAD_GRAYSCALE)
    return image_decoded,label

#使用标准的TensorFlow操作来给图像重新规划大小
def _resize_function(image_decoded,label):
    image_decoded.set_shape([None,None,None])
    image_resized=tf.image.resize_images(image_decoded,[28,28])
    return image_resized,label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset=tf.data.Dataset.from_tensor_slices((filenames,labels))
dataset=dataset.map(
    lambda filename,label:tuple(tf.py_func(
        _read_py_function,[filename,label],[tf.unit8,label.dtype]
    ))
)
dataset=dataset.map(_resize_function)


四、分批数据集元素

4.1 简单的分批

最简单分批次的形式是将数据集的前n个连续的元素放入一个元素当中。Dataset.batch()就是这样的,跟tf.stack()具有相同的限制。


inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
# ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
# ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

with tf.Session() as sess:
    print(sess.run(next_element))
    print(sess.run(next_element))
    print(sess.run(next_element))

4.2 通过padding来分批次tensor

前面的例子的张量大小一样。很多的模型的输入数据的大小是不一致的,所以为了处理这样的数据,Dataset.padded_batch()使你可以将大小不一样的张量,指定维度来填充。

dataset=tf.data.Dataset.range(100)
dataset=dataset.map(lambda x:tf.fill([tf.cast(x,tf.int32)],x))
#【】表示的是shape,x表示数字,1,2,2,,3,3,3这样的
dataset=dataset.padded_batch(4,padded_shapes=[None])

iterator=dataset.make_one_shot_iterator()
next_element=iterator.get_next()

with tf.Session() as sess:
    print(sess.run(next_element))
    print(sess.run(next_element))


# [[0 0 0]
#  [1 0 0]
#  [2 2 0]
#  [3 3 3]]
# [[4 4 4 4 0 0 0]
#  [5 5 5 5 5 0 0]
#  [6 6 6 6 6 6 0]
#  [7 7 7 7 7 7 7]]

五、训练工作流

5.1 处理多个迭代周期

第一种方式是将Dataset.repeat()重复多次来实现跑10个epochs:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)

如果repeat没有参数,那么就重复无限次的重复数据。如果你想捕捉一个epoch的end,使用tf.errors.OutOfRangeError。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Compute for 100 epochs.
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]

5.2 随机的打乱数据

Dataset.shuffle()随机的打乱输入的数据,tf.RandomShuffleQueue()保持一个固定size的空间,随机的选择下一个元素。

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()

5.3 使用高级接口

tf.train.MonitoredTrainingSession() API简化了TensorFlow在分布式上运行的各个方面。它使用tf.errors.OutOfRangeError来表明训练完成。与tf.data接口使用,我们推荐使用Dataset.make_one_shot_iterator().

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)

input_fn函数里面使用一个数据集,tf.estimator.Estimator,我们建议使用Dataset.make_one_shot_iterator():

def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels

猜你喜欢

转载自blog.csdn.net/m0_37393514/article/details/81180456