Tensorflow(四)- CNN_CIFAR(一)- cifar10_input

这篇开始,讲述自己对于tensorflow文档中,利用CNN建立CIFAR-10模型的理解,如有错误欢迎指正,也是互相学习。由于代码太长,所以分几篇来讲述。
第一篇是关于cifar10_input.py文件。

cifar10_input.py

这个文件主要用来进行数据读取以及输入数据处理。
Tensorflow一共有三种读取数据的方式:
第一种最简单的预加载数据,直接在graph中定义常量和变量来保存数据(仅适用于数据小的情况,神经网络自然这种方法不行)
第二种供给数据(feeding),那么这种方法前面出现过很多次,也就是在图中建立placeholder,然后在跑图的时候再对占位进行数据feed。(如果数据量过大,一次性读入所有数据,再分批次feed进图,也会占用太多的内存空间)
第三种从文件中读取数据,在Tensorflow的起始,用一个输入管线从文件中读取数据。
那么这个模型就是运用了第三种数据读取方式。
这里我们按照实际建立图的顺序来进行讲解,由cifar10_train.py文件我们知道先调用的是distorted_input函数。

distorted_input

函数输入为data_dir(训练集所在文件夹),以及batch_size。
函数输出为图像组成的4维tensor,以及labels组成的1维tensor。
下面为数据读取步骤:
1. 生成文件名列表,也就是把需要输入的所有文件的文件名放到一个列表里。
2. 将文件名列表输入到tf.train.string_input_producer()函数中,生成一个先入先出的文件名队列,同时将一个QueueRunner添加到整个图的QUEUE_RUNNER当中。(tf.train.QueueRunner本质上是tensorflow的一个类,用来完成队列的一系列入队操作)
下面为API文档,具体请参考文档。
https://www.tensorflow.org/api_docs/python/tf/train/string_input_producer
3. 利用文件阅读器读取文件名队列中文件里的数据。接下来我们跟随代码转移到read_cifar10函数中。
6. 从read_cifar10函数回来,这个时候read_input已经是刚刚输出的结构体了,它有一个样本的所有信息。然后对图像数据进行一系列预处理,这是data augmentation,包括随机裁取,随机左右翻转,随机亮度调节,随机对比度调节,图像归一化。
所有关于图像操作的API文档
https://www.tensorflow.org/api_guides/python/image
7. 定义参数min_queue_examples,然后进入到generate_image_and_label_batch()函数中。

def distorted_inputs(data_dir, batch_size):
    # 生成文件名列表
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    for f in filenames:
        if not gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # 生成文件名队列
    filename_queue = tf.train.string_input_producer(filenames)
    # 从文件名队列中的文件中读取出样本数据
    read_input = read_cifar10(filename_queue)
    # 图像数据类型转换
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 对图像进行预处理,data augmentation
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)
    float_image = tf.image.per_image_standardization(distorted_image)

  # 定义min_queue_examples
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *min_fraction_of_examples_in_queue)
    print ('Filling queue with %d CIFAR images before starting to train. '
           'This will take a few minutes.' % min_queue_examples)

    return generate_image_and_label_batch(float_image, read_input.label,min_queue_examples, batch_size)

read_cifar10

函数输入,filename_queue(文件名队列)。
函数输出为一个代表一个单独数据样本的结构体。这个结构体包含图像样本的height,width,depth,key(用来表征这个样本的输入文件以及其record number),label(它的类别)以及图像样本本身(uint8数据类型的3维tensor)
3. 承接着上面的第三步,将文件名队列输入给read_cifar10()函数,根据要读取的文件的格式以及文件本身记录的属性,建立相应的文件阅读器,这里我们需要读取的是二进制文件而且cifar10数据集每条记录的长度是固定的,所以对应我们建立一个固定长度记录阅读器,也就是按固定字节进行文件阅读。
tf.FixedLengthRecordReader(),下面为API文档
https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader
4. 利用上面建立的阅读器的read()操作读取出(key,value)的tensor元组。key就是上面输出中提到的key,value自然是样本数据,然后利用与tf.FixedLengthRecordReader配套的tf.decode_raw()操作来将字符串转换为数字张量。(数字的格式可以利用out_type进行设定,默认是uint8)
tf.decode_raw() API文档
https://www.tensorflow.org/api_docs/python/tf/decode_raw
在这里需要说明每次read()的执行,只从文件中读取出一个样本。
5. 后续进行一个切片以及类型转换的处理之后,完成结构体的输出。然后回到distorted_input()函数里。

def read_cifar10(filename_queue):
    # 建立一个空类,方便数据的结构化存储
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()
    # 确定一个样本的字节数
    label_bytes = 1 
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes
    # 建立reader,从二进制文件中读取固定长度记录
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    # read 返回tuple of tensors (key, value), 返回由reader生成的下一条记录(key, value)对
    result.key, value = reader.read(filename_queue)
    # 将字符流的数据翻译成数字张量
    record_bytes = tf.decode_raw(value, tf.uint8)

    # 单独提出样本中的label并进行类型转换uint8转为int32
    result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
    # 提出样本中的数据reshape成图像格式
    depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
                             [result.depth, result.height, result.width])
    # 转换[depth, height, width]到[height, width, depth].
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    return result

generate_image_and_label_batch

函数输入为,image(3维的图像数据tensor),label(图像标签),min_queue_examples(保留在样本队列中的最小样本数),batch_size。
函数输出带有batch_size的4维图像tensor以及带有batch_size的2维label。
7. 紧接着上面的第7步进入到generate_image_and_label_batch()函数中,该函数其实就进行了一个函数操作 tf.train.shuffle_batch()。下面是API文档
https://www.tensorflow.org/api_docs/python/tf/train/shuffle_batch
该函数将下面三个东西加入到当前的计算图中:

  • 一个带有固定大小(capacity)的随机(shuffling)样本队列。
  • 一个dequeue_many操作,用来从样本队列中完成多样本出队,也就是创建带有batch的样本然后输出。
  • 将一个QueueRunner加入到图的QUEUE_RUNNER集合中,用来进行样本入队。

这里对min_after_dequeue参数进行说明,它表示在队列进行出队操作后,队列所需要保留的最小样本数,也就是当出队后,队列样本数不够该参数就要进行样本填充。这是为了保证shuffle的效果。
还有num_threads是线程数量,通过设定多线程可以快速完成入队操作,相当于同时读取一个文件的多个样本,然后入队。
8. 完成images和labels的输出。

def generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size):

    num_preprocess_threads = 16
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)

    tf.summary.image('images', images)
    return images, tf.reshape(label_batch, [batch_size])

总结

实际上,从文件中读取数据,我们建立了两个队列,一个是文件名队列,一个是样本队列,其大致过程就是利用多线程从文件名队列中读取样本到样本队列中,然后最后输出带有batch的数据。
下面有几个小要点。
1. 在我们建立好计算图,完成初始化后,在跑run和eval之前,我们需要运行tf.train.start_queue_runners(),来启动我们建立的这个输入管线,启动计算图的QUEUE_RUNNER中所有的QueueRunner,开始对队列进行入队操作。
2. 跑过程序的知道,distorted_inputs里面有一句print,这句话只显示了一次,说明本身程序只跑了一次来建立输入管线,输入管线建立好以后,后面无数次的入队出队都是在后台异步运行。
3. 样本队列出队是因为shuffle_batch函数将一个dequeue_many操作加入到了计算图,而我们好像并没有看到对于文件名队列的出队操作。其实文件名队列的出队操作是由文件阅读器的read函数完成了。
4. shuffle_batch第一次需要队列填充满,才开始进行出队,这也是为什么程序跑第一个epoch很慢,因为大量的时间用来完成第一次的队列填充。

下面链接为本次博文参考
http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html


小广告

淘宝choker、耳饰小店 物理禁止
女程序员编码时和编码之余 都需要一些美美的choker、耳饰来装扮自己
男程序员更是需要常备一些来送给自己心仪的人
淘宝小店开店不易 希望有缘人多多支持 (O ^ ~ ^ O)
本号是本人 只是发则小广告 没有被盗 会持续更新深度学习相关博文和一些翻译
感谢大家 不要拉黑我 ⊙﹏⊙|||°
这里写图片描述

猜你喜欢

转载自blog.csdn.net/mike112223/article/details/78539917
今日推荐