Tensorflow读取csv文件

前言:当深度学习使用的训练数据文件过大,使用pandas读取时会一次性读取全部数据,给内存带来了极大的压力。Tensorflow提供了一个使用队列且多线程读取文件的机制,缓解了内存的压力。该程序完整代码:https://github.com/iapcoder/TensorflowReadCSV


一 步骤:

1、构造文件队列

file_queue = tf.train.string_input_producer(file_list) # file_list: csv文件路径列表

2、构造csv阅读器读取队列数据(读取的是一行)

reader = tf.TextLineReader(skip_header_lines=1) # skip_header_lines 指定跳过几行
key, value = reader.read(file_queue) # key:行号 value: 内容

 3、对每行的内容进行解码

records = [[1.0], [1]] # 指定每一列的类型,1.0表示是浮点型,缺失则为1.0, 1表示整型,缺失则为1,“None”表示字符串,缺失则为None
example, label = tf.decode_csv(value, record_defaults=records) # 返回的是一行的数据

4、若想读取多个数据,需要使用批处理

example_batch, label_batch = tf.train.batch([example, label], batch_size=9, num_threads=2, capacity=9) # batch_size:要读取多少行 num_threads:指定多少个子线程 capacity:指定队列容量

5、开启会话获取数据

with tf.Session() as sess:

    coord = tf.train.coordinator() # 定义一个线程协调器
    threads = tf.train.start_queue_runner(sess, coord=coord) # 开启读取文件的线程
    data = sess.run([example_batch, label_batch]) # 获取数据
    coord.request_stop() # 请求关闭线程
    coord.join(threads)  # 主线程等待子线程结束

二 实例

假设有三个csv文件,如下图所示,第一行为表头。

构造三个文件路径的列表,传入给文件队列

file_name = os.listdir("../datas/")
file_list = [os.path.join("../datas/", file) for file in file_name]

经过Tensorflow读取文本数据

 

猜你喜欢

转载自blog.csdn.net/qq_41689620/article/details/88769582