Deep learning和tensorflow学习记录(三十一):Feed data

一、Constants

最简单的方法。

import tensorflow as tf
import numpy as np

actual_data = np.random.normal(size=[100])

data = tf.constant(actual_data)

这种方式效率很高,但是当需要应用于其他数据时,必须重写。并且这种方式一次性将所有数据加载到内存,仅适用小数据集。

二、Placeholders

import tensorflow as tf
import numpy as np

data = tf.placeholder(tf.float32)

prediction = tf.square(data) + 1

actual_data = np.random.normal(size=[100])

tf.Session().run(prediction, feed_dict={data: actual_data})

placeholders 是在session run中通过feed_dict来feed数据。

三、Python ops

def py_input_fn():
    actual_data = np.random.normal(size=[100])
    return actual_data

data = tf.py_func(py_input_fn, [], (tf.float32))

四、Dataset API

tensorflow中推荐使用dataset API来feed 数据。

actual_data = np.random.normal(size=[100])
dataset = tf.contrib.data.Dataset.from_tensor_slices(actual_data)
data = dataset.make_one_shot_iterator().get_next()
dataset = dataset.cache()
if mode == tf.estimator.ModeKeys.TRAIN:
    dataset = dataset.repeat()
    dataset = dataset.shuffle(batch_size * 5)
dataset = dataset.map(parse, num_threads=8)
dataset = dataset.batch(batch_size)

如果是从文件中获取数据,将文件数据转化成TFrecord格式再用TFRecordDataset读取效率更高。

dataset = tf.contrib.data.TFRecordDataset(path_to_data)

猜你喜欢

转载自blog.csdn.net/heiheiya/article/details/81093126
今日推荐