浅谈GAN网络-二

本文接着上一篇不太成功的博客https://blog.csdn.net/qq_30666517/article/details/81262849继续进行,昨天花了点时间,使之前的工作有一个较好的结果。上一篇博客是自己第一次接触GAN时所写,自己尝试写了一个最简单的全连接GAN网络,结果输出效果很差,这篇博客就是分析上篇不太好结果的原因,以及对于训练GAN一些体会。

上篇博客代码中有2个缺陷,也是比较容易忽略的问题。第一:生成器随机噪声输入,应在训练的时候动态给定,也就是说生成器噪声输入应该定义一个占位符,之前的代码直接是noise_inputs=tf.random_normal(shape=[128,100]),这样这条语句只会执行一次,每次训练的时候会导致生成器输入的随机噪声是一个固定的一维向量,本博客代码对此做了修正。第二:GoodFellow原始GAN论文生成器最后一层激励函数使用的是sigmoid函数,后来有论文指出使用tanh函数收敛更快效果更好,上篇博客也是使用tanh函数,但是发现效果很差,后来自己做了一下实验,把tanh改成sigmoid生成的质量更高,这就造成了一个矛盾,后来仔细检查代码,发现了一个问题所在。一般我们习惯上将输入图片简单归一化到[0,1],但是这时使用sigmoid激活函数比tanh更好,因为sigmoid值域就是[0,1],而tanh函数值域是[-1,1],所以如果生成器最后一层采用tanh激活函数,切记要把输入缩放在[-1,1]之间,缩放方式可以简单采取为:img = img/255.0*2-1.0,这样发现GAN的训练更稳定,收敛速度确实更快,生成质量确实更高。

一点小小的经验:

1.生成器和判别器最好都使用leaky_relu作为激活函数,出了生成器最后一层使用tanh,以及判别器最后一层使用sigmoid之外,这样可以防止在训练过成中的梯度消失现象,效果确实比用relu作为激活函数,得到的图片质量更高。

2.由于GAN的终极目标还是为了得到高质量的生成图片,判别器的工作相对与生成器的工作更轻松,所以建议在写生成器代码时,生成器比判别器网络更复杂,会有更好的效果。

3.优化器采用DCGAN中的结论,Adam优化器,lr=2e-4,beta1=0.5

附上几张重新生成的图片:

 

 

相比上一篇博客,已经好了很多了,也差不多到了全连接GAN的极限了。

附上改进后的代码:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 27 15:06:58 2018

@author: wsw
"""

# create mnist fc GAN

import tensorflow as tf
import numpy as np
import os
import time
import matplotlib.pyplot as plt
slim = tf.contrib.slim

tf.reset_default_graph()

def load_data():
    dataPath = '../mnist/train.npy'
    trainData = np.load(dataPath)
    return trainData



def get_next_batch(Datas,batchsize=128):
    image = tf.train.slice_input_producer([Datas],
                                          num_epochs=200,
                                          shuffle=True,
                                          )
    # form a batch date
    image_batch = tf.train.batch([image],
                                 batch_size=batchsize,
                                 capacity=1000,
                                 num_threads=4)
    return image_batch


def generator(inputs,is_training=True):
    
    
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.leaky_relu,
                        ):
        net = slim.fully_connected(inputs,num_outputs=128,scope='fc1')
        net = slim.fully_connected(net,num_outputs=256,scope='fc2')
        net = slim.fully_connected(net,num_outputs=512,scope='fc3')
        net = slim.fully_connected(net,num_outputs=784,
                                   activation_fn=tf.nn.tanh,
                                   scope='fc4')
        return net

    
def discriminator(inputs,is_training=True):
    
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.leaky_relu,
                        ):
        net = slim.fully_connected(inputs,num_outputs=512,scope='fc1')
        net = slim.dropout(net,keep_prob=0.8,is_training=is_training)
        net = slim.fully_connected(net,num_outputs=1,
                                   activation_fn=None,
                                   scope='fc2')

        return net


def train_GAN():
    # load data
    train_Data = load_data()
    train_Data = np.float32(train_Data/255.0)*2-1.0
    # get batch data
    batchsize = 128
    image_batch = get_next_batch(train_Data,batchsize=batchsize)
    # build gan model
    # random noise inputs
    noise_inputs = tf.placeholder(tf.float32,shape=[batchsize,100])
    
    with tf.variable_scope('Generator',reuse=tf.AUTO_REUSE):
        gen_imgs = generator(noise_inputs)
        
    with tf.variable_scope('Discriminator',reuse=tf.AUTO_REUSE):
        # compute d_loss
        # true data score
        d_truth_score = discriminator(image_batch)
        # fake data score
        d_fake_score = discriminator(gen_imgs,is_training=True)
        g_fake_score = discriminator(gen_imgs,is_training=False)
    
    with tf.name_scope('compute_accuracy'):
        predict_fake_label = tf.where(tf.nn.sigmoid(g_fake_score)>0.5,
                                      tf.ones([batchsize,1],
                                               dtype=tf.uint8),
                                      tf.zeros([batchsize,1],
                                               dtype=tf.uint8))
                                      
        gt_fake_label = tf.zeros(shape=[batchsize,1],dtype=tf.uint8)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(predict_fake_label,gt_fake_label),
                                          dtype=tf.float32))
        
    with tf.name_scope('D_loss'):
        # d_loss = -tf.reduce_mean(tf.log(truth_score)+tf.log(1-fake_score))
        d_truth_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_truth_score,
                                                               labels=tf.ones_like(d_truth_score))
        d_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_score,
                                                              labels=tf.zeros_like(d_fake_score))
        d_loss = tf.reduce_mean(d_fake_loss+d_truth_loss)
    
    with tf.name_scope('G_loss'):
        # g_loss = -tf.reduce_mean(tf.log(fake_score))
        g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=g_fake_score,
                                                        labels=tf.ones_like(g_fake_score)))
    
    
    
    with tf.name_scope('optimizer'):
        global_step = tf.train.create_global_step()
        optimizer = tf.train.AdamOptimizer(2e-4,beta1=0.5)
        # get generator variable list
        gen_vars = slim.get_variables(scope='Generator')
        # get discriminator variable list
        disc_vars = slim.get_variables(scope='Discriminator')
        print('Generator Trainable Variables',gen_vars)
        print('Discriminator Trainable Variables',disc_vars)
        train_G = optimizer.minimize(g_loss,global_step,var_list=gen_vars)
        train_D = optimizer.minimize(d_loss,global_step,var_list=disc_vars)
        
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # must needed 
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess,coord)
        try:
            epoch = 1
            while not coord.should_stop():
                # random noise inputs
                noise = np.random.normal(size=[128,100]).astype(np.float32)
                start = time.time()
                # train Discriminator
                d_loss_value,_,accu = sess.run([d_loss,train_D,accuracy],
                                               feed_dict={noise_inputs:noise})
                # train Generator
                g_loss_value,_ = sess.run([g_loss,train_G],
                                          feed_dict={noise_inputs:noise})
                end = time.time()
                step = global_step.eval()//2
                fmt = 'Epoch:{:02d}-Step:{:05d}-Gloss:{:.3f}-Dloss:{:.3f}-D_accu:{:.5f}-Elapsed:{:.3f}(Sec)'\
                .format(epoch,step,g_loss_value,d_loss_value,accu,end-start)
                if step%100==0:
                    print(fmt)
                if step%470==0:
                    epoch += 1
                if step%4700 == 0:
                    valid_imgs_out = sess.run(gen_imgs,
                                              feed_dict={noise_inputs:noise})
                    show_result(valid_imgs_out,epoch)
                    
        except tf.errors.OutOfRangeError: 
            coord.request_stop()
            print('train finished!!!')
        coord.join(threads)


        
def show_result(valid_imgs,epoch):
    imgDir = './images2'
    if not os.path.exists(imgDir):
        os.mkdir(imgDir)
    fig,ax = plt.subplots(5,5,sharex='all',sharey='all')
    for idx in range(25):
        img = valid_imgs[idx]
        img = img.reshape(28,28)
        row, col_index = divmod(idx,5)
        col = col_index%5
        ax[row,col].imshow(img,cmap='gray')
        ax[row,col].axis('off')
    fig.savefig(os.path.join(imgDir,'%d.png'%epoch))
    plt.close()
    
    
if __name__ == '__main__':
    train_GAN()

猜你喜欢

转载自blog.csdn.net/qq_30666517/article/details/81866664