Adversarial Network Learning (9): CartoonGAN + a Web Crawler to Generate The Garden of Words (言叶之庭) Style Images (TensorFlow Implementation)

I. Background

CartoonGAN is a model proposed by Yang Chen et al. (CVPR 2018). It targets cartoon-style image generation, introducing a new GAN architecture and two loss functions; compared with earlier GAN models for cartoon stylization, the images CartoonGAN generates are of noticeably higher quality.

For this experiment I crawled frames from The Garden of Words (言叶之庭, a Makoto Shinkai film) and trained on them to generate images in that style.

[1] Paper: http://openaccess.thecvf.com/content_cvpr_2018/papers/Chen_CartoonGAN_Generative_Adversarial_CVPR_2018_paper.pdf

II. How CartoonGAN Works

The paper is fairly recent and there are not many write-ups about it yet; a good Chinese introduction to CartoonGAN is:

[2] 实景照片秒变新海诚风格漫画:清华大学提出CartoonGAN (Turning real photos into Makoto Shinkai style cartoons: Tsinghua University proposes CartoonGAN)

For cartoon style transfer, CycleGAN (covered earlier in this series) can also be used, but the authors examined such models closely and found that they do not generate cartoon-style images well, because:

However, existing methods do not produce satisfactory results for cartoonization, due to the fact that (1) cartoon styles have unique characteristics with high level simplification and abstraction, and (2) cartoon images tend to have clear edges, smooth color shading and relatively simple textures, which exhibit significant challenges for texture-descriptor-based loss functions used in existing methods. 

The reasons are twofold: first, cartoon styles are highly simplified and abstract; second, cartoon images have very clean edges, smooth color shading, and relatively simple textures.

The key point is that earlier models rely on texture-descriptor-based loss functions and therefore cannot produce clean edges and smooth shading. To illustrate this, the authors compare neural style transfer (NST), CycleGAN, and CartoonGAN:

To address these problems the authors introduce CartoonGAN. The paper's contributions are threefold:

The main contributions of this paper are: 

(1) We propose a dedicated GAN-based approach that effectively learns the mapping from real-world photos to cartoon images using unpaired image sets for training. Our method is able to generate high-quality stylized cartoons, which are substantially better than state-of-the-art methods. When cartoon images from individual artists are used for training, our method is able to reproduce their styles. (A dedicated GAN architecture that, trained on unpaired data, learns the mapping from real-world photos to cartoon images)

(2) We propose two simple yet effective loss functions in GAN-based architecture. In the generative network, to cope with substantial style variation between photos and cartoons, we introduce a semantic loss defined as an ℓ1 sparse regularization in the high-level feature maps of the VGG network. In the discriminator network, we propose an edge-promoting adversarial loss for preserving clear edges. (Two loss functions: a semantic content loss, an ℓ1 sparse regularization on high-level VGG feature maps, in the generator, and an edge-promoting adversarial loss in the discriminator)

(3) We further introduce an initialization phase to improve the convergence of the network to the target manifold. Our method is much more efficient to train than existing methods. (An initialization phase that improves convergence of the network to the target manifold)

The paper then describes the CartoonGAN architecture. The overall framework consists of two CNNs; the schematic of the network structure is shown below:

Overall, the generator resembles an auto-encoder: downsampling first, then upsampling. The discriminator is a fairly standard CNN. The key improvement lies in the two new losses.

The authors also point out that with random initialization the GAN, being highly nonlinear, easily gets stuck in sub-optimal local minima. Since the generator's output should preserve the semantic content of the input, the generator is pre-trained with the content loss only.
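
For reference, the generator objective from the paper can be written as follows (ω is the content-loss weight; both the paper and the code below take the VGG19 conv4_4 feature map and use ω = 10):

    L(G, D) = L_adv(G, D) + ω · L_con(G, D)
    L_con(G, D) = E_p [ || VGG(G(p)) − VGG(p) ||_1 ]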

Then come the experiments. The real-world photos were downloaded from Flickr (which used to require a VPN to access from China), 6,154 in total, of which 5,402 were used for training; the cartoon images were captured from video (presumably one frame every few frames), 4,212 in total. All images were then cropped and resized to 256×256.

The authors' results are shown below:

The results look quite good. The paper also compares the model against several others; read the original if you are interested.

As for implementation, the authors used PyTorch, but fortunately TensorFlow versions can be found on GitHub; here are two references:

[3]https://github.com/taki0112/CartoonGAN-Tensorflow

[4]https://github.com/SystemErrorWang/CartoonGAN

I mainly followed code [3]. The original code is very well written, so I only made small changes; the implementation is described below.

III. CartoonGAN Implementation

1. File structure

The project layout is as follows:

-- utils.py
-- layer.py
-- vgg19.py
-- vgg19.npy                            # download this yourself; see below
-- edge_smooth.py
-- cartoonGAN.py
-- main.py
-- dataset                              # dataset directory
    |------ trainA                      # cartoon images
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ trainB                      # real-world photos
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ trainB_smooth               # generated with edge_smooth.py
                |------ image1.png
                |------ image2.png
                |------ ......
    |------ testA                       # test images
                |------ image1.png
                |------ image2.png
                |------ ......

2. Data preparation

Two kinds of data need to be prepared:

(1) vgg19.npy

Since the model uses VGG19, we first need to download this file. As with DeblurGAN earlier in this series, downloading it requires a VPN; if you already downloaded it for DeblurGAN, just reuse that copy. Download link:

https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs

Note that the first time I connected through a non-US exit node, the site limited the download and asked me to buy a membership; switching the exit node to the US allowed the full download.

For convenience I have also uploaded the file to Baidu Cloud:

Baidu Cloud link: https://pan.baidu.com/s/1GluBif6N1u9eiosICI12Ng

Extraction code: dzsa

After downloading, place the file in the project root, i.e. './vgg19.npy'.
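
A quick way to verify the file before training (my own sketch, not part of the original code; on newer NumPy versions allow_pickle=True has to be passed explicitly, which is why the same flag appears in vgg19.py below):

import numpy as np

# load the VGG19 weight dictionary and list its layer names
data = np.load('./vgg19.npy', encoding='latin1', allow_pickle=True).item()
print(sorted(data.keys()))   # should include conv1_1 ... conv5_4 plus the fc layers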

(2) Dataset

All the data was crawled and preprocessed by myself. For the cartoon images I chose The Garden of Words (a Makoto Shinkai film), crawled frames and resized them to 256×256; for the real-world photos, since my earlier CycleGAN experiment [5] already used a real-photo/Van Gogh dataset, I simply reused those photos. For how to crawl the data and obtain the real photos, see the earlier posts [5] and [6]:

[5] Adversarial Network Learning (3): CycleGAN for Van Gogh style image translation (TensorFlow implementation)

[6] Some notes on crawling images with Python

I have also uploaded my dataset to my CSDN resources; feel free to download it if needed.

Here is the download link; since the dataset took some effort to build, it costs 2 CSDN points:

https://download.csdn.net/download/z704630835/10801038

Because uploads must be smaller than 220 MB, I trimmed the real-photo set to 5,400 images, which does not affect training.

My prepared dataset looks roughly like this:

The trainA folder contains 647 images and the trainB folder contains 6,277 images. Note that after preparing the dataset you still need to run edge_smooth.py to generate the trainB_smooth data; just set the relevant paths in edge_smooth.py and run it (a small usage sketch follows the script in section 6).

3. Data-loading helpers: utils.py

utils.py mainly contains functions for loading and saving images:

import tensorflow as tf
from tensorflow.contrib import slim
from scipy import misc
import os
import numpy as np


class ImageData:

    def __init__(self, load_size, channels):
        self.load_size = load_size
        self.channels = channels

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        return img


def load_test_data(image_path, size=256):
    img = misc.imread(image_path, mode='RGB')
    img = misc.imresize(img, [size, size])
    img = np.expand_dims(img, axis=0)
    img = preprocessing(img)

    return img


def preprocessing(x):
    x = x/127.5 - 1                     # -1 ~ 1
    return x


def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)


def inverse_transform(images):
    return (images+1.) / 2


def imsave(images, size, path):
    return misc.imsave(path, merge(images, size))


def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[h*j:h*(j+1), w*i:w*(i+1), :] = image

    return img


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def str2bool(x):
    return x.lower() == 'true'   # exact comparison; the original substring check accepted any part of 'true'
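
As a quick illustration of how these helpers fit together (my own sketch, assuming TensorFlow 1.x and a few images already placed under ./dataset/trainA):

from glob import glob
import tensorflow as tf
from utils import ImageData, save_images

paths = glob('./dataset/trainA/*.*')
loader = ImageData(load_size=256, channels=3)
dataset = tf.data.Dataset.from_tensor_slices(paths).map(loader.image_processing).batch(1)
img = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    batch = sess.run(img)                     # pixel values scaled to [-1, 1]
    save_images(batch, [1, 1], 'check.png')   # rescaled back to [0, 1] before saving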

4. Layer definitions: layer.py

layer.py defines the basic layers and loss functions:

import tensorflow as tf
import tensorflow.contrib as tf_contrib
from vgg19 import Vgg19

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None

##################################################################################
# Layer
##################################################################################


def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if (kernel - stride) % 2 == 0:
            pad_top = pad
            pad_bottom = pad
            pad_left = pad
            pad_right = pad

        else :
            pad_top = pad
            pad_bottom = kernel - stride - pad_top
            pad_left = pad
            pad_right = kernel - stride - pad_left

        if pad_type == 'zero':
            x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
        if pad_type == 'reflect':
            x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], initializer=weight_init,
                                regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x


def deconv(x, channels, kernel=4, stride=2, use_bias=True, sn=False, scope='deconv_0'):
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()
        output_shape = [x_shape[0], x_shape[1]*stride, x_shape[2]*stride, channels]
        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding='SAME')

            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           kernel_regularizer=weight_regularizer,
                                           strides=stride, padding='SAME', use_bias=use_bias)

        return x


##################################################################################
# Residual-block
##################################################################################


def resblock(x_init, channels, use_bias=True, scope='resblock_0'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias)
            x = instance_norm(x)

        return x + x_init

##################################################################################
# Activation function
##################################################################################


def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def tanh(x):
    return tf.tanh(x)

##################################################################################
# Normalization function
##################################################################################


def instance_norm(x, scope='instance_norm'):
    return tf_contrib.layers.instance_norm(x,
                                           epsilon=1e-05,
                                           center=True, scale=True,
                                           scope=scope)


def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def l2_norm(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)


##################################################################################
# Loss function
##################################################################################


def L1_loss(x, y):
    return tf.reduce_mean(tf.abs(x - y))


def discriminator_loss(loss_func, real, fake, real_blur):
    real_loss = 0
    fake_loss = 0
    real_blur_loss = 0

    if loss_func == 'wgan-gp' or loss_func == 'wgan-lp':
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
        real_blur_loss = tf.reduce_mean(real_blur)

    if loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.square(real - 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
        real_blur_loss = tf.reduce_mean(tf.square(real_blur))

    if loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
        real_blur_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real_blur),
                                                                                logits=real_blur))

    if loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))
        real_blur_loss = tf.reduce_mean(relu(1.0 + real_blur))

    loss = real_loss + fake_loss + real_blur_loss

    return loss


def generator_loss(loss_func, fake):
    fake_loss = 0

    if loss_func == 'wgan-gp' or loss_func == 'wgan-lp':
        fake_loss = -tf.reduce_mean(fake)

    if loss_func == 'lsgan' :
        fake_loss = tf.reduce_mean(tf.square(fake - 1.0))

    if loss_func == 'gan' or loss_func == 'dragan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        fake_loss = -tf.reduce_mean(fake)

    loss = fake_loss

    return loss


def vgg_loss(real, fake):
    vgg = Vgg19('vgg19.npy')

    vgg.build(real)
    real_feature_map = vgg.conv4_4_no_activation

    vgg.build(fake)
    fake_feature_map = vgg.conv4_4_no_activation

    loss = L1_loss(real_feature_map, fake_feature_map)

    return loss

5. VGG19 model: vgg19.py

This is the same VGG19 used in DeblurGAN, so it is unchanged; the code is:

import tensorflow as tf
import numpy as np
import time

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg19:

    def __init__(self, vgg19_npy_path=None):
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()   # allow_pickle is required on newer NumPy
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG
        input format: bgr image with shape [batch_size, h, w, 3]
        scale: (-1, 1)
        """

        start_time = time.time()
        rgb_scaled = ((rgb + 1) / 2) * 255.0        # [-1, 1] ~ [0, 255]

        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        bgr = tf.concat(axis=3, values=[blue - VGG_MEAN[0],
                                        green - VGG_MEAN[1],
                                        red - VGG_MEAN[2]])

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
        self.pool3 = self.max_pool(self.conv3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")

        self.conv4_4_no_activation = self.no_activation_conv_layer(self.conv4_3, "conv4_4")

        self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
        self.pool4 = self.max_pool(self.conv4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
        self.pool5 = self.max_pool(self.conv5_4, 'pool5')

        print(("Finished building vgg19: %ds" % (time.time() - start_time)))

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def no_activation_conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            x = tf.nn.bias_add(conv, conv_biases)

            return x

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")
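
As a small check of the layer used by the content loss (again my own sketch; it assumes vgg19.npy sits in the project root), the conv4_4 feature map of a 256×256 input has shape 32×32×512:

import tensorflow as tf
from vgg19 import Vgg19

x = tf.placeholder(tf.float32, [1, 256, 256, 3])   # input already scaled to [-1, 1]
vgg = Vgg19('./vgg19.npy')
vgg.build(x)
print(vgg.conv4_4_no_activation.shape)             # (1, 32, 32, 512)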

6. Edge smoothing: edge_smooth.py

edge_smooth.py blurs the edges of the cartoon images, producing the edge-smoothed set used by the edge-promoting adversarial loss:

from utils import check_folder
import numpy as np
import cv2, os, argparse
from glob import glob
from tqdm import tqdm

def parse_args():
    desc = "Edge smoothed"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--dataset', type=str, default='img2anime', help='dataset_name')
    parser.add_argument('--img_size', type=int, default=256, help='The size of image')

    return parser.parse_args()

def make_edge_smooth(dataset_name, img_size) :
    check_folder('./dataset/trainB_smooth/')

    file_list = glob('./dataset/trainB/*.*')
    save_dir = './dataset/trainB_smooth'    # keep consistent with check_folder above and the path read by cartoonGAN.py

    kernel_size = 5
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    gauss = cv2.getGaussianKernel(kernel_size, 0)
    gauss = gauss * gauss.transpose(1, 0)

    for f in tqdm(file_list):
        file_name = os.path.basename(f)

        bgr_img = cv2.imread(f)
        gray_img = cv2.imread(f, 0)

        bgr_img = cv2.resize(bgr_img, (img_size, img_size))
        gray_img = cv2.resize(gray_img, (img_size, img_size))

        # detect edges with Canny and dilate them
        edges = cv2.Canny(gray_img, 100, 200)
        dilation = cv2.dilate(edges, kernel)

        h, w = edges.shape

        # apply Gaussian blur only to the dilated edge regions
        gauss_img = np.copy(bgr_img)
        for i in range(kernel_size // 2, h - kernel_size // 2):
            for j in range(kernel_size // 2, w - kernel_size // 2):
                if dilation[i, j] != 0:  # gaussian blur to only edge
                    gauss_img[i, j, 0] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 0], gauss))
                    gauss_img[i, j, 1] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 1], gauss))
                    gauss_img[i, j, 2] = np.sum(np.multiply(bgr_img[i - kernel_size // 2:i + kernel_size // 2 + 1,
                                                            j - kernel_size // 2:j + kernel_size // 2 + 1, 2], gauss))

        cv2.imwrite(os.path.join(save_dir, file_name), gauss_img)

"""main"""
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    make_edge_smooth(args.dataset, args.img_size)


if __name__ == '__main__':
    main()
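
With the dataset laid out as in section 1, this step can be run directly (the defaults in parse_args already point at ./dataset/trainB), or called from Python; a small usage sketch:

from edge_smooth import make_edge_smooth

# reads ./dataset/trainB/*, blurs the dilated Canny edges, writes to ./dataset/trainB_smooth/
make_edge_smooth('img2anime', 256)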

7. Model file: cartoonGAN.py

cartoonGAN.py defines the model architecture and the operations around it:

from layer import *
from utils import *
from glob import glob
import time
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
import numpy as np


class CartoonGAN(object):
    def __init__(self, sess, args):
        self.model_name = 'CartoonGAN'
        self.sess = sess
        self.checkpoint_dir = args.checkpoint_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir
        self.dataset_name = args.dataset

        self.epoch = args.epoch
        self.init_epoch = args.init_epoch       # args.epoch // 20
        self.iteration = args.iteration
        self.decay_flag = args.decay_flag
        self.decay_epoch = args.decay_epoch

        self.gan_type = args.gan_type

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.init_lr = args.lr
        self.ch = args.ch

        """ Weight """
        self.adv_weight = args.adv_weight
        self.vgg_weight = args.vgg_weight
        self.ld = args.ld

        """ Generator """
        self.n_res = args.n_res

        """ Discriminator """
        self.n_dis = args.n_dis
        self.n_critic = args.n_critic
        self.sn = args.sn

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.sample_dir = os.path.join(args.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        self.trainA_dataset = glob('./dataset/trainA/*.*')
        self.trainB_dataset = glob('./dataset/trainB/*.*')
        self.trainB_smooth_dataset = glob('./dataset/trainB_smooth/*.*')

        self.dataset_num = max(len(self.trainA_dataset), len(self.trainB_dataset))

        print()

        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# max dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# init_epoch : ", self.init_epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# residual blocks : ", self.n_res)

        print()

        print("##### Discriminator #####")
        print("# the number of discriminator layer : ", self.n_dis)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

        print()

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, x_init, reuse=False, scope="generator"):
        channel = self.ch
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=7, stride=1, pad=3, pad_type='reflect', use_bias=False, scope='conv')
            x = instance_norm(x, scope='ins_norm')
            x = relu(x)

            # Down-Sampling
            for i in range(2):
                x = conv(x, channel*2, kernel=3, stride=2, pad=1, use_bias=True, scope='conv_s2_'+str(i))
                x = conv(x, channel*2, kernel=3, stride=1, pad=1, use_bias=False, scope='conv_s1_'+str(i))
                x = instance_norm(x, scope='ins_norm_'+str(i))
                x = relu(x)

                channel = channel * 2

            # Bottleneck
            for i in range(self.n_res):
                x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i))

            # Up-Sampling
            for i in range(2):
                x = deconv(x, channel//2, kernel=3, stride=2, use_bias=True, scope='deconv_'+str(i))
                x = conv(x, channel//2, kernel=3, stride=1, pad=1, use_bias=False, scope='up_conv_'+str(i))
                x = instance_norm(x, scope='up_ins_norm_'+str(i))
                x = relu(x)

                channel = channel // 2

            x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, pad_type='reflect', 
                     use_bias=True, scope='G_logit')
            x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x_init, reuse=False, scope="discriminator"):
        channel = self.ch // 2
        with tf.variable_scope(scope, reuse=reuse):
            x = conv(x_init, channel, kernel=3, stride=1, pad=1, use_bias=True, sn=self.sn, scope='conv_0')
            x = lrelu(x, 0.2)

            for i in range(1, self.n_dis):
                x = conv(x, channel * 2, kernel=3, stride=2, pad=1, use_bias=True, sn=self.sn, scope='conv_s2_' + str(i))
                x = lrelu(x, 0.2)

                x = conv(x, channel * 4, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='conv_s1_' + str(i))
                x = instance_norm(x, scope='ins_norm_' + str(i))
                x = lrelu(x, 0.2)

                channel = channel * 2

            x = conv(x, channel * 2, kernel=3, stride=1, pad=1, use_bias=False, sn=self.sn, scope='last_conv')
            x = instance_norm(x, scope='last_ins_norm')
            x = lrelu(x, 0.2)

            x = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=True, sn=self.sn, scope='D_logit')

            return x

    ##################################################################################
    # Model
    ##################################################################################
    def gradient_panalty(self, real, fake, scope="discriminator"):
        if self.gan_type.__contains__('dragan'):
            eps = tf.random_uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region

            fake = real + 0.5 * x_std * eps

        alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        logit = self.discriminator(interpolated, reuse=True, scope=scope)

        grad = tf.gradients(logit, interpolated)[0]             # gradient of D(interpolated)
        grad_norm = tf.norm(tf.layers.flatten(grad), axis=1)    # l2 norm

        GP = 0
        # WGAN - LP
        if self.gan_type.__contains__('lp'):
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif self.gan_type.__contains__('gp') or self.gan_type == 'dragan' :
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    def build_model(self):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainB_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
                          drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))
        trainB_smooth = trainB_smooth.apply(shuffle_and_repeat(self.dataset_num)).apply(map_and_batch(
            Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
            drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()
        trainB_smooth_iterator = trainB_smooth.make_one_shot_iterator()

        self.real_A = trainA_iterator.get_next()
        self.real_B = trainB_iterator.get_next()
        self.real_B_smooth = trainB_smooth_iterator.get_next()

        self.test_real_A = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='test_real_A')

        """ Define Generator, Discriminator """
        self.fake_B = self.generator(self.real_A)

        real_B_logit = self.discriminator(self.real_B)
        fake_B_logit = self.discriminator(self.fake_B, reuse=True)
        real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)

        """ Define Loss """
        if self.gan_type.__contains__('gp') or self.gan_type.__contains__('lp') or self.gan_type.__contains__('dragan') :
            GP = self.gradient_panalty(real=self.real_B, fake=self.fake_B) + self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
        else :
            GP = 0.0

        v_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
        g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
        d_loss = self.adv_weight * discriminator_loss(self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

        self.Vgg_loss = v_loss
        self.Generator_loss = g_loss + v_loss
        self.Discriminator_loss = d_loss

        """ Result Image """
        self.test_fake_B = self.generator(self.test_real_A, reuse=True)

        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.init_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Vgg_loss, var_list=G_vars)
        self.G_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss, var_list=D_vars)

        """" Summary """
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", v_loss)

        self.V_loss_merge = tf.summary.merge([self.G_vgg])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
        self.D_loss_merge = tf.summary.merge([self.D_loss])

    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        lr = self.init_lr
        for epoch in range(start_epoch, self.epoch):
            # lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)
            if self.decay_flag :
                lr = self.init_lr * pow(0.5, epoch // self.decay_epoch)

            for idx in range(start_batch_id, self.iteration):

                train_feed_dict = {self.lr: lr}

                if epoch < self.init_epoch:
                    # Init G
                    real_A_images, fake_B_images, _, v_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                             self.init_optim,
                                                                             self.Vgg_loss, self.V_loss_merge], feed_dict = train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f v_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, v_loss))

                else:
                    # Update D
                    _, d_loss, summary_str = self.sess.run([self.D_optim, self.Discriminator_loss, self.D_loss_merge], feed_dict = train_feed_dict)
                    self.writer.add_summary(summary_str, counter)

                    # Update G
                    g_loss = None
                    if (counter - 1) % self.n_critic == 0 :
                        real_A_images, fake_B_images, _, g_loss, summary_str = self.sess.run([self.real_A, self.fake_B,
                                                                                              self.G_optim,
                                                                                              self.Generator_loss, self.G_loss_merge], feed_dict = train_feed_dict)
                        self.writer.add_summary(summary_str, counter)
                        past_g_loss = g_loss

                    if g_loss == None:
                        g_loss = past_g_loss
                    print("Epoch: [%3d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # display training status
                counter += 1

                if np.mod(idx+1, self.print_freq) == 0 :
                    save_images(real_A_images, [self.batch_size, 1],
                                './{}/real_A_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))
                    save_images(fake_B_images, [self.batch_size, 1],
                                './{}/fake_B_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model for final step
            self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        n_res = str(self.n_res) + 'resblock'
        n_dis = str(self.n_dis) + 'dis'
        return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name,
                                                         self.gan_type, n_res, n_dis,
                                                         self.n_critic, self.sn,
                                                         int(self.adv_weight), int(self.vgg_weight))

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir) # checkpoint file information

        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path) # first line
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(ckpt_name.split('-')[-1])
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def test(self):
        tf.global_variables_initializer().run()
        test_A_files = glob('./dataset/testA/*.*')

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        self.result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(self.result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # write html for visual comparison
        index_path = os.path.join(self.result_dir, 'index.html')
        index = open(index_path, 'w')
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")

        for sample_file in test_A_files:                           # A -> B
            print('Processing A image: ' + sample_file)
            sample_image = np.asarray(load_test_data(sample_file))
            image_path = os.path.join(self.result_dir, '{0}'.format(os.path.basename(sample_file)))

            fake_img = self.sess.run(self.test_fake_B, feed_dict={self.test_real_A: sample_image})
            save_images(fake_img, [1, 1], image_path)

            index.write("<td>%s</td>" % os.path.basename(image_path))

            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (sample_file if os.path.isabs(
                sample_file) else ('../..' + os.path.sep + sample_file), self.img_size, self.img_size))
            index.write("<td><img src='%s' width='%d' height='%d'></td>" % (image_path if os.path.isabs(
                image_path) else ('../..' + os.path.sep + image_path), self.img_size, self.img_size))
            index.write("</tr>")

        index.close()

8. Main script: main.py

main.py defines the model parameters and the training and testing entry points:

from cartoonGAN import CartoonGAN   # file name matches cartoonGAN.py above
import argparse
from utils import *

"""parsing and configuration"""

def parse_args():
    desc = "Tensorflow implementation of CartoonGAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')
    parser.add_argument('--dataset', type=str, default='img2anime', help='dataset_name')

    parser.add_argument('--epoch', type=int, default=500, help='The number of epochs to run')
    parser.add_argument('--init_epoch', type=int, default=1, help='The number of epochs for weight initialization')
    parser.add_argument('--iteration', type=int, default=500, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
    parser.add_argument('--print_freq', type=int, default=100, help='The number of image_print_freq')
    parser.add_argument('--save_freq', type=int, default=100, help='The number of ckpt_save_freq')
    parser.add_argument('--decay_flag', type=str2bool, default=False, help='The decay_flag')
    parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch')

    parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--adv_weight', type=float, default=1.0, help='Weight about GAN')
    parser.add_argument('--vgg_weight', type=float, default=10.0, help='Weight about VGG19')
    parser.add_argument('--gan_type', type=str, default='gan', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge')

    parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
    parser.add_argument('--n_res', type=int, default=8, help='The number of resblock')

    parser.add_argument('--n_dis', type=int, default=3, help='The number of discriminator layer')
    parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
    parser.add_argument('--sn', type=str2bool, default=False, help='using spectral norm')

    parser.add_argument('--img_size', type=int, default=256, help='The size of image')
    parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
    # parser.add_argument('--augment_flag', type=str2bool, default=False, help='Image augmentation use or not')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())


"""checking arguments"""
def check_args(args):
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --result_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # --epoch
    try:
        assert args.epoch >= 1
    except:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')
    return args


"""main"""
def main():
    # parse arguments
    args = parse_args()

    check_args(args)

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = CartoonGAN(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")

if __name__ == '__main__':
    main()

IV. Results

How to run the model: first prepare the dataset, then set phase to 'train' in main.py and train; when training finishes, set phase to 'test' (test images go in ./dataset/testA) to see the final results.

    parser.add_argument('--phase', type=str, default='test', help='train or test ?')

I trained for only 100 epochs, 100 iterations per epoch, with batch_size set to 1 (my GPU has only 3 GB of memory and reported out-of-memory even with batch_size 4, so I ended up with 1), i.e. roughly 10,000 image updates in total. Probably because of the limited training, the results are not very pronounced; here they are:

Since the original cartoon dataset itself leans green, the generated images have an obvious green cast.


Second update:

Yesterday's results were not great. I suspect the model simply was not trained long enough, since the outputs had a very obvious pixelated, grainy look, so I raised epoch to 500 and iteration to 500, i.e. 250,000 image updates. Here is what the generated images look like now:

Honestly... it still hardly produces cartoon images at all. It is better than the 100-epoch run, with noticeably less graininess and fewer green patches, but overall there is still no visible cartoon style.

V. Analysis

1. For the file structure, see section III.

2. For cartoon-style image generation, CartoonGAN produces clearly better results than CycleGAN and similar models (according to the paper's comparisons; my own short training runs above did not get that far).

Reposted from blog.csdn.net/z704630835/article/details/84336398