A Hands-On Guide to Building a GAN (Generative Adversarial Network) Model

This tutorial implements a simple GAN (Generative Adversarial Network) model with Keras. It is based on the mnistGAN project, and we will use the Keras Python library.

Python packages we will use:

  • keras: for the neural network models
  • matplotlib: for plotting
  • tensorflow: the backend that keras runs on
  • tqdm: for progress bars
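
If you want to confirm your environment before starting, here is a minimal sanity check (not part of the original post; it assumes the four packages above are already installed, e.g. via pip):

# Optional: print the versions of the packages used in this tutorial.
import keras
import matplotlib
import tensorflow
import tqdm

print("keras:", keras.__version__)
print("matplotlib:", matplotlib.__version__)
print("tensorflow:", tensorflow.__version__)
print("tqdm:", tqdm.__version__)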

Talk is cheap, show me the code.

1. Import the packages:

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input  # input layer  
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers

2. Define a few variables:

# Let Keras know that we are using tensorflow as our backend engine.
# Note: this environment variable is read when keras is imported, so strictly
# speaking it should be set before the keras imports in step 1.
os.environ["KERAS_BACKEND"] = "tensorflow"

# To make sure that we can reproduce the experiment and get the same results by same random output
np.random.seed(10)

# The dimension of our random noise vector.
random_dim = 100

3. Prepare the data:

We will use the MNIST dataset of handwritten digits. MNIST has been "chewed over" by countless tutorials and has become something of a canonical example. A quick introduction:

The MNIST dataset comes from the National Institute of Standards and Technology (NIST) in the United States. The training set consists of digits handwritten by 250 different people, 50% of them high school students and 50% staff of the Census Bureau. The test set contains handwritten digits collected in the same proportions.

The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts:

  • Training set images (60,000 samples)
  • Training set labels (60,000 labels)
  • Test set images (10,000 samples)
  • Test set labels (10,000 labels)

Each image in the MNIST dataset is made up of 28 x 28 pixels, and each pixel is represented by a grayscale value. The labels hold the corresponding target variable, i.e. the class label of the handwritten digit (an integer from 0 to 9).

(Figure: sample images from the MNIST dataset.)
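
The figure from the original post is not reproduced here. If you want to look at a few raw samples yourself, here is a minimal sketch (not part of the original tutorial) using the imports from step 1:

# Not part of the original tutorial: display five raw 28 x 28 grayscale
# MNIST samples together with their labels.
(sample_images, sample_labels), (_, _) = mnist.load_data()
plt.figure(figsize=(6, 2))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(sample_images[i], cmap='gray')
    plt.title(int(sample_labels[i]))
    plt.axis('off')
plt.show()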

Code:

def load_minst_data():
    # Load the data. mnist.load_data() is part of Keras and makes it easy to
    # import the MNIST dataset; y_train and y_test are the label sets.
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # normalize our inputs to be in the range [-1, 1]
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    # convert x_train with a shape of (60000, 28, 28) to (60000, 784)
    # so we have 784 columns per row
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)
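
A quick way to check that the loader behaves as described (a small usage sketch, not from the original post; the first call downloads MNIST if it is not cached locally):

x_train, y_train, x_test, y_test = load_minst_data()
print(x_train.shape)                 # (60000, 784)
print(x_train.min(), x_train.max())  # approximately -1.0 and 1.0
print(y_train[:10])                  # the first ten digit labels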

4. The generator and the discriminator

Now we build the generator and the discriminator; both networks are optimized with the Adam optimizer. Each of the two networks uses three hidden layers with Leaky ReLU as the activation function. One extra trick: to improve the discriminator's robustness on data it has never seen, Dropout layers are added to it.

# You will use the Adam optimizer
def get_optimizer():
    return Adam(lr=0.0002, beta_1=0.5)

def get_generator(optimizer):
    generator = Sequential()
    generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(784, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

def get_discriminator(optimizer):
    discriminator = Sequential()
    discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator
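
Before the training step (which the original post continues elsewhere), here is a small usage sketch, not part of the original tutorial, that builds both networks and pushes a batch of random noise through them to confirm the tensor shapes:

# Usage sketch (an assumption about how the pieces fit together, not the
# original training loop): build both networks and check the shapes.
adam = get_optimizer()
generator = get_generator(adam)
discriminator = get_discriminator(adam)

noise = np.random.normal(0, 1, size=(16, random_dim))  # a batch of 16 noise vectors
fake_images = generator.predict(noise)                 # (16, 784), values in [-1, 1] from tanh
predictions = discriminator.predict(fake_images)       # (16, 1), probability each image is "real"
print(fake_images.shape, predictions.shape)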

To be continued. You can read more in the link below:
Read more: https://www.datacamp.com/community/tutorials/generative-adversarial-networks


Reposted from www.cnblogs.com/sonictl/p/12323843.html