Generative Adversarial Networks (GANs)

Generative Adversarial Networks (GANs, short for Generative Adversarial Nets) consist of two networks trained against each other: a generator that maps random noise to images, and a discriminator that tries to tell real images from generated ones. This post implements a minimal fully connected GAN in TensorFlow and trains it on MNIST.
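The two networks play a minimax game. With D the discriminator, G the generator, p_data the data distribution and p_z the noise prior, the objective from the original paper (Goodfellow et al., 2014) is

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

In practice the generator is usually trained to maximize $\log D(G(z))$ rather than minimize $\log(1 - D(G(z)))$, which gives stronger gradients early in training. The d_cost and g_cost built in build_model below are the negated, batch-averaged estimates of these two terms.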

import os
from datetime import datetime
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 128
LEARNING_RATE = 1e-4
Z_DIM = 100
IMAGE_W = 28
IMAGE_H = 28
model_dir = 'model_gan'  # checkpoint directory
img_dir = 'img'          # directory for generated sample grids
os.makedirs(model_dir, exist_ok=True)  # saver.save fails if the directory does not exist
os.makedirs(img_dir, exist_ok=True)

# placeholder for a batch of real MNIST images, flattened to 784 = 28 * 28
x_in = tf.placeholder(tf.float32, shape=[None, 784])


def load_mnist():
    return input_data.read_data_sets("./MNIST_data", one_hot=True)


mnist = load_mnist()


def get_W_b(input_dim, output_dim, name):
    # name is e.g. 'gen_W_b_1': the weight becomes 'gen_W_1' and the bias 'gen_b_1',
    # so all generator variables share the 'gen' prefix and all discriminator
    # variables share the 'discrim' prefix (used to split them in train()).
    W = tf.Variable(tf.random_normal([input_dim, output_dim], stddev=0.02), name=name.replace('_b', ''))
    b = tf.Variable(tf.zeros([output_dim], tf.float32), name=name.replace('_W', ''))
    return W, b


HIDDEN_DIM = 256  # hidden layer size shared by the generator and the discriminator


class GAN(object):
    def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE, z_dim=Z_DIM):
        self.lr = lr
        self.batch_size = batch_size
        self.z_dim = z_dim

        # Generator weights: z_dim -> HIDDEN_DIM -> 784
        self.gen_W1, self.gen_b1 = get_W_b(z_dim, HIDDEN_DIM, 'gen_W_b_1')
        self.gen_W2, self.gen_b2 = get_W_b(HIDDEN_DIM, IMAGE_H * IMAGE_W, 'gen_W_b_2')

        # Discriminator weights: 784 -> HIDDEN_DIM -> 1
        self.discrim_W1, self.discrim_b1 = get_W_b(IMAGE_H * IMAGE_W, HIDDEN_DIM, 'discrim_W_b_1')
        self.discrim_W2, self.discrim_b2 = get_W_b(HIDDEN_DIM, 1, 'discrim_W_b_2')

    # Discriminator: maps a flattened image to the probability that it is real
    def discriminator(self, x):
        d_h1 = tf.nn.relu(tf.add(tf.matmul(x, self.discrim_W1), self.discrim_b1))
        d_h2 = tf.add(tf.matmul(d_h1, self.discrim_W2), self.discrim_b2)
        return tf.nn.sigmoid(d_h2)

    # Generator: maps a noise vector z to a flattened image with values in [0, 1]
    def generator(self, z):
        g_h1 = tf.nn.relu(tf.add(tf.matmul(z, self.gen_W1), self.gen_b1))
        g_h2 = tf.add(tf.matmul(g_h1, self.gen_W2), self.gen_b2)
        return tf.nn.sigmoid(g_h2)

    # Build the graph: losses for both networks plus values for monitoring
    def build_model(self):
        # sample noise inside the graph so every training step sees a fresh z batch
        z_sample = tf.random_uniform([self.batch_size, self.z_dim], minval=-1., maxval=1.)
        g_image = self.generator(z_sample)

        d_real = self.discriminator(x_in)
        d_fake = self.discriminator(g_image)

        # Standard GAN losses with the non-saturating generator objective;
        # the small epsilon keeps tf.log away from log(0)
        eps = 1e-8
        d_cost = -tf.reduce_mean(tf.log(d_real + eps) + tf.log(1. - d_fake + eps))
        g_cost = -tf.reduce_mean(tf.log(d_fake + eps))
        return d_cost, g_cost, tf.reduce_mean(d_real), tf.reduce_mean(d_fake)


# Plot 16 generated samples on a 4x4 grid
def plot_grid(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(IMAGE_H, IMAGE_W), cmap='Greys_r')
    return fig


# Training loop
def train():
    with tf.Session() as sess:
        gan = GAN()
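        # Variables were named with 'gen'/'discrim' prefixes in get_W_b, so we can
        # split them here and let each optimizer update only its own network.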
        discrim_vars = list(filter(lambda x: x.name.startswith('discrim'), tf.trainable_variables()))
        gen_vars = list(filter(lambda x: x.name.startswith('gen'), tf.trainable_variables()))

        d_cost, g_cost, d_real, d_fake = gan.build_model()

        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        d_opt = optimizer.minimize(d_cost, var_list=discrim_vars)
        g_opt = optimizer.minimize(g_cost, var_list=gen_vars)

        saver = tf.train.Saver()
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)  # resume from the latest checkpoint
            print('checkpoint: {}'.format(checkpoint))
        else:
            # train from scratch: initialize all variables
            sess.run(tf.global_variables_initializer())

        print("Started training {}".format(datetime.now().isoformat()[11:]))
        plot_index = 0

        for step in range(100000):
            batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
            sess.run(d_opt, feed_dict={x_in: batch_x})
            sess.run(g_opt, feed_dict={x_in: batch_x})
            # every 1000 steps: log the losses, save a grid of generated images, and checkpoint the model
            if step % 1000 == 0:
                batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
                d_cost_, d_real_, d_fake_ = sess.run([d_cost, d_real, d_fake], feed_dict={x_in: batch_x})
                g_cost_ = sess.run(g_cost, feed_dict={x_in: batch_x})
                print("step:{} Discriminator Loss {} Generator loss {}  d_real:{}  d_feak:{}".format(step, d_cost_,
                                                                                                     g_cost_, d_real_,
                                                                                                     d_fake_))

                # draw a 4x4 grid of samples from fresh noise
                # (calling generator() here adds new ops to the graph on every save; acceptable for a demo)
                z_sample = np.random.uniform(-1., 1., size=[16, Z_DIM]).astype('float32')
                g_image = sess.run(gan.generator(z_sample))
                fig = plot_grid(g_image)
                plt.savefig('{}/{}.png'.format(img_dir, str(plot_index).zfill(4)), bbox_inches='tight')
                plot_index += 1
                plt.close(fig)

                # save a checkpoint
                saver.save(sess, "{}/model_gan.model".format(model_dir), global_step=step)

        print("Ended training {}".format(datetime.now().isoformat()[11:]))


if __name__ == "__main__":
    train()
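
After training has written at least one checkpoint into model_gan, new digits can be generated without rerunning the training loop. The helper below is a minimal sketch (not from the original post; the function name sample_images and the output file sample.png are illustrative choices): it rebuilds the generator in a fresh graph, restores the latest checkpoint by variable name, and saves a 4x4 grid of samples with the same plot_grid helper. Call it manually, e.g. sample_images().

def sample_images(n=16, out_path='sample.png'):
    # Build the generator in a fresh graph so variable names match the checkpoint exactly.
    with tf.Graph().as_default(), tf.Session() as sess:
        gan = GAN()
        z = np.random.uniform(-1., 1., size=[n, Z_DIM]).astype('float32')
        g_image = gan.generator(z)
        saver = tf.train.Saver()
        # assumes a checkpoint exists in model_dir; tf.train.latest_checkpoint returns None otherwise
        saver.restore(sess, tf.train.latest_checkpoint(model_dir))
        samples = sess.run(g_image)
        fig = plot_grid(samples)  # expects 16 samples for the 4x4 grid
        plt.savefig(out_path, bbox_inches='tight')
        plt.close(fig)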

Reposted from blog.csdn.net/shuishou07/article/details/78793653