使用猫狗大战数据集进行一次完整的TensorFlow训练

1.简介

一直想将图片制作成tfrecords文件,然后在模型中运行一下。最初想用的数据集是mnist,但是跑的过程中一直出现问题。找到这一篇知乎上的博客,写的非常不错。

原博客地址:https://zhuanlan.zhihu.com/p/32490882

其代码地址:https://github.com/HelloSangShen/Cat-vs-Dog/

猫狗数据集:https://pan.baidu.com/s/13hw4LK8ihR6-6-8mpjLKDA   密码:dmp4


2. 本文结构

本文以kaggle的猫狗大战为例,完整地描述使用TensorFlow进行一次完整CNN训练的每个步骤。首先介绍如何将图片转为TFRecords文件,然后介绍如何读取该文件的数据并且输入给我们的网络进行训练,并且会展示如何通过hook来监测网络训练的情况(这里没有使用TensorBoard)。最后会简单解读一下MonitoredTrainingSession的使用方法。

3. 正文

3.1 数据处理

有过实践的小伙伴应该能感受到,当有了TensorFlow、PyTorch这样优秀的框架后,构造一个神经网络、进行训练、计算损失函数、预测等都变的相对容易许多。但是数据的预处理仍然是一个相对棘手的问题,尤其是在较大数据集上进行训练时,不能总是使用占位符(placeholder)和feed dict进行数据加载,而TensorFlow提供了另外一种加载方式。这部分就着重介绍如何将图片数据存储为TFRecords,并且通过队列读取给我们的网络。因为网上有非常多介绍TFRecords原理的文章,我这里就不细说了,只给出详细的代码和注释,示范一下如何处理。

def read_images(path):
    """从源文件/路径读取图像

    参数:
        path: 图像所在的路径即文件夹名称
    返回:
        返回一个带有所有图像、标签和总数信息的对象
        images: 所有的图像数据
        labels: 所有标签
        num: 数目
    """

    # 获取文件夹内所有图像文件的文件名和总数
    filenames = next(walk(path))[2]
    num_file = len(filenames)

    # 初始化图像和标签
    images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8)
    labels = np.zeros((num_file, ), dtype=np.uint8)

    # 遍历读取文件
    for index, filename in enumerate(filenames):
        # 读取单张图像,并且修改为自定义尺寸
        img = imread(join(path, filename))
        img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT))
        images[index] = img

        # TO DO
        # 这里通过文件名获取标签信息,猫狗大战问题中只有两类,故只有0和1
        # 可以根据自己的需要进行修改
        # 注意:这里不是one-hot编码
        if filename[0:3] == 'cat':
            labels[index] = int(0)
        else:
            labels[index] = int(1)

        if index % 1000 == 0:
            print("Reading the %sth image" % index)
    
    # 创建一个类,该类携带图像、标签和总数信息
    class ImgData(object):
        pass

    result = ImgData()
    result.images = images
    result.labels = labels
    result.num = num_file

    return result

通过上述函数,我们可以读取到文件夹内所有的图片。接下来,我们要把这些图片转为TFRecords文件。

def convert(data, destination):
    """将图片存储为.tfrecords文件

    参数:
        data: 上述函数返回的ImageData对象
        destination: 目标文件名
    """

    images = data.images
    labels = data.labels
    num_examples = data.num

    # 存储的文件名
    filename = destination
    
    # 使用TFRecordWriter来写入数据
    writer = tf.python_io.TFRecordWriter(filename)
    # 遍历图片
    for index in range(num_examples):
        # 转为二进制
        image = images[index].tostring()
        label = labels[index]
        # tf.train下有Feature和Features,需要注意其区别
        # 层级关系为Example->Features->Feature
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        }))
        # 写入
        writer.write(example.SerializeToString())
    writer.close()

这两个函数就可以把我们的数据集图片全都写入一个.tfrecords文件。如果文件过大,可以写入多个文件。

下面介绍如何从tfrecords文件中批量读取图片和标签。

def read_and_decode(filename_queue):
    """读取.tfrecords文件

    参数:
        filename_queue: 文件名, 一个列表

    返回:
        img, label: **单张图片和对应标签**
    """
    # 创建一个图节点,该节点负责数据输入
    filename_queue = tf.train.string_input_producer([filename_queue])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    
    # 解析单个example
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    })

    image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    image = tf.cast(image, tf.float32)
    label = tf.cast(features['label'], tf.int64)

    return image, label

我们将数据读取的功能进行封装,代码如下:

def distorted_input(filename, batch_size):
    """建立一个乱序的输入
    
    参数:
      filename: tfrecords文件的文件名. 注:该文件名仅为文件的名称,不包含路径和后缀
      batch_size: 每次读取的batch size
      
    返回:
      images: 一个4D的Tensor. size: [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3]
      labels: 1D的标签. size: [batch_size]
    """
    # 完整文件名,文件存储在同一路径下的tfrecords文件夹下,名为filename.tfrecords
    filename = './tfrecords/' + filename + '.tfrecords'
    
    # 如果路径下没有该文件,说明没有进行转换工作,则将图片转为tfrecords文件
    if not os.path.exists(filename):
        print('Transfer images to TF_Records')
        raw_data = read_images(FLAGS.raw_data_path)
        convert(raw_data, filename)
        print('End transfering')

    image, label = read_and_decode(filename)
    # 乱序读入一个batch
    images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                    num_threads=16, capacity=3000, min_after_dequeue=1000)

    return images, labels

以上,我们就完成了数据的读取部分了。下面用一段代码进行测试。

images, labels = catdog_input.distorted_input(FLAGS.tfrecords_file_name, batch_size=4)

    # from matplotlib import pyplot as plt
    fig = plt.figure()
    a = fig.add_subplot(221)
    b = fig.add_subplot(222)
    c = fig.add_subplot(223)
    d = fig.add_subplot(224)

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        # 开启文件读取队列,开启后才能开始读取数据
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        
        img, label = sess.run([images, labels])
        
        a.imshow(img[0])
        a.axis('off')

        b.imshow(img[1])
        b.axis('off')

        c.imshow(img[2])
        c.axis('off')

        d.imshow(img[3])
        d.axis('off')

        plt.show()

        coord.request_stop()
        coord.join(threads)

通过这个简单的测试程序就可以可视化四张图片出来。

3.2 模型

这里,我们使用VGG-16模型来做测试。TensorFlow在搭建网络上非常方便,这里就不给详细代码了(可以参考cs.toronto.edu/~frossar),读者可以在文末的GitHub链接上找到相关代码。

对于准确率、损失函数等,我们参考TensorFlow教程中Cifar10训练的源代码进行实现,将这些函数均封装起来。

def loss(logits, labels):

    labels = tf.cast(labels, tf.int64)
    # 注意:我们上面定义的标签不是one-hot编码,故这里调用的是sparse方法
    # 如果使用one-hot,调用softmax_cross_entropy_with_logits即可
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    loss = tf.reduce_mean(cross_entropy, name='cross_entropy')
    return loss

  
def accuracy(logits, labels):
    
    # 将labels转为one-hot编码进行计算
    labels = tf.one_hot(labels, NUM_CLASS)
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    return accuracy

  
def train(loss):
    
    train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
    return train_op

至此,我们的模型就搭建好了。接下来就是训练步骤。

3.3 训练

下面的train()函数也是参照cifar10的源码进行实现的。

def train():

    # 因为要使用StopAtStepHook,故global_step是必须的
    global_step = tf.train.get_or_create_global_step()

    # 输入
    images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)
    
    logits = catdog_model.inference(images)
    loss = catdog_model.loss(logits, labels)
    # accuracy = catdog_model.accuracy(logits, labels)

    train_op = catdog_model.train(loss)

    class _LoggerHook(tf.train.SessionRunHook):
        """ 
        该类用来打印训练信息
        """
        def begin(self):
            self._step = -1
            self._start_time = time.time()

        def before_run(self, run_context):
            self._step += 1
            # 该函数在训练运行之前自动调用
            # 在这里返回所有你想在运行过程中查看到的信息
            # 以list的形式传递,如:[loss, accuracy]
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):

            # 打印信息的步骤间隔
            display_step = 10
            if self._step % display_step == 0:
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time
                # results返回的就是上面before_run()的返回结果,上面是loss故这里是loss
                # 若输入的是list,返回也是一个list
                loss = run_values.results

                # 每秒使用的样本数
                examples_per_sec = display_step * BATCH_SIZE / duration
                # 每batch使用的时间
                sec_per_batch = float(duration / display_step)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss,
                                    examples_per_sec, sec_per_batch))

                
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],  # 将上面定义的_LoggerHook传入
            config=tf.ConfigProto(
                log_device_placement=False)) as sess:

        coord = tf.train.Coordinator()
        # 开启文件读取队列
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        while not sess.should_stop():
            sess.run(train_op)

        coord.request_stop()
        coord.join(threads)

上面就是在猫狗大战数据集上进行的一个完整的图片数据预处理、数据读取、搭建网络、训练并监测的过程。

3.4 评估

因为实验室设备暂时有点问题,没法训练,故现在没法给出结果,以后训练出结果后再来更新吧。

3.5 关于MonitoredTrainingSession

我们在上面的训练中用到了tf.train.MonitoredTrainingSession(...)。查阅了一下官方文档,该类继承自MonitoredSession类。我们先看看这个父类,官方文档中给了一段如下示例代码 :

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        sess.run(train_op)

首先,当MonitoredSession初始化的时候,会按顺序执行下面操作:

  • 调用hook的begin()函数,我们一般在这里进行一些hook内的初始化。比如在上面猫狗大战中的_LoggerHook里面的_step属性,就是用来记录执行步骤的,但是该参数只在本类中起作用。
  • 通过调用scaffold.finalize()初始化计算图
  • 创建会话
  • 通过初始化Scaffold提供的操作(op)来初始化模型
  • 如果checkpoint存在的话,restore模型的参数
  • launches queue runners
  • 调用hook.after_create_session()

然后,当run()函数运行的时候,按顺序执行下列操作:

  • 调用hook.before_run()
  • 调用TensorFlow的 session.run()
  • 调用hook.after_run()
  • 返回用户需要的session.run()的结果
  • 如果发生了AbortedError或者UnavailableError,则在再次执行run()之前恢复或者重新初始化会话

最后,当调用close()退出时,按顺序执行下列操作:

  • 调用hook.end()
  • 关闭队列和会话
  • 阻止OutOfRange错误

需要注意的是:该类不是一个tf.Session() ,因为它不能被设置为默认会话,不能被传递给saver.save,也不能被传递给tf.train.start_queue_runners,这也解释了为什么在开启会话后我们必须手动调用tf.train.start_queue_runners()

MonitoredTrainingSession则比起父类多了许多其他的参数,可以在官方文档获取各参数的说明,这里我们不详细说。但是根据其父类的执行说明,我们就可以很容易理解上面train()函数中发生了什么。

首先,我们先将计算图的各个节点/操作定义好,构成了一个计算图。然后开启了一个MonitoredTrainingSession来初始化/注册我们的图和其他信息。其中,我们给其传递了3个hook:

  • tf.train.StopAtStepHook(last_step),该hook主要是在训练到特定步数后即请求停止,使用该hook必须要预先定义一个tf.train.get_or_create_global_step()。否则会抛出运行时错误,见源码:

def begin(self): self._global_step_tensor = training_util._get_or_create_global_step_read() if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use StopAtStepHook.")

  • tf.train.NanTensorHook(loss),该hook用来监测loss,若loss的结果为NaN,抛出异常或者直接停止训练。
  • _LoggerHook(),该hook是我们自定义的hook,用来监测我们希望在训练过程中能查看的一些数据如loss或者accuracy。首先会随着MonitoredTrainingSession的初始化来调用begin()函数,我们在这里初始化步数,before_run()函数会随着sess.run()的调用而调用。故每训练一步调用一次,这里返回想要打印的信息,随后就调用after_run()函数,在这里,我们就将需要查看的信息打印出来即可。

随后,我们开启文件读取队列进行数据的输入。然后就一直调用sess.run()训练直到停下。

4.如何运行

首先得生成tfrecords文件,在当前文件夹下新建一个create_tfrecords.py,然后将下面的代码放进去(其实就是上面的代码)

import tensorflow as tf
import numpy as np
import os

from scipy.misc import imread,imresize
from os.path import join
from os import walk

IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_CHANNEL = 3
NUM_CLASS = 2

def read_images(path):
    """Read image from source file/directory

    Args:
        path: source derectory
    Return:
        An object representing all images and labels, fields:
        images: all image data
        labels: all labels
        num: number of images
    """

    # Get a list filenames
    filenames = next(walk(path))[2]
    num_file = len(filenames)

    # Initialize images and labels.
    images = np.zeros((num_file, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL), dtype=np.uint8)
    labels = np.zeros((num_file, ), dtype=np.uint8)

    # Iterate/Read all files
    for index, filename in enumerate(filenames):
        # Read single image and resize it to your expected size
        img = imread(join(path, filename))
        img = imresize(img, (IMAGE_HEIGHT, IMAGE_HEIGHT))
        images[index] = img

        # TO DO:
        if filename[0:3] == 'cat':
            labels[index] = int(0)
        else:
            labels[index] = int(1)

        if index % 1000 == 0:
            print("Reading the %sth image" % index)

    class ImgData(object):
        pass

    result = ImgData()
    result.images = images
    result.labels = labels
    result.num = num_file

    return result


def convert(data, destination):
    """Convert images to tfrecords

    Args:
        data: an object of ImgData, consisting of images, labels and number of images
        destination: destination filename of tfrecords
    """

    images = data.images
    labels = data.labels
    num_examples = data.num

    # filenale of tfrecords
    filename = destination

    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image = images[index].tostring()
        label = labels[index]

        # Attention: Example -> Features -> Feature
        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        }))
        writer.write(example.SerializeToString())
    writer.close()


if __name__ == '__main__':
    path = 'kaggle/train'
    tfrecords_path = 'tfrecords/cat_dog.tfrecords'
    data = read_images(path)
    convert(data,tfrecords_paths)

然后直接命令python create_tfrecords.py

然后直接命令python catdog_train.py --tfrecords_name cat_dog \ --max_step 5000

运行结果:


猜你喜欢

转载自blog.csdn.net/pursuit_zhangyu/article/details/80581215