这个教程我们用MNIST数据集来训练成生对抗网格(GAN)。MNIST是28x28像素手写数字图像的大的集合。我们将训练网格来产生新的手写数字图像。
In [ ]:
!curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import conda_installer
conda_installer.install()
!/root/miniconda/bin/conda info -e
In [ ]:
!pip install --pre deepchem
import deepchem
deepchem.__version__
开始,我们需要导入所有我们需要的库并加载数据集(数据集来自Tensorflow)
In [1]:
import deepchem as dc
import tensorflow as tf
from deepchem.models.optimizers import ExponentialDecay
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Reshape
import matplotlib.pyplot as plot
import matplotlib.gridspec as gridspec
%matplotlib inline
mnist = tf.keras.datasets.mnist.load_data(path='mnist.npz')
images = mnist[0][0].reshape((-1, 28, 28, 1))/255
dataset = dc.data.NumpyDataset(images)
我们来看一下图像是什么样子的。
In [2]:
def plot_digits(im):
plot.figure(figsize=(3, 3))
grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)
for i, g in enumerate(grid):
ax = plot.subplot(g)
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(im[i,:,:,0], cmap='gray')
plot_digits(images)
现在我们来创建自已的GAN。像上一个教程一样,它包含两部分:
1.生成器以随机噪音为输入,产生与训练数据相似的输出。
2.分判器以一些样本作为输入(可能是训练样本也可能是生成器产生的样本),并尽量分判真假。
这次我们使用不同风格的GAN叫做Wasserstein GAN (或简称WGAN)。很多情况下,它们被发现能产生比条件GAN更好的结果。这两者的主要区别在于分判器(本文叫"critic")。不是输出样本为真实训练数据的概率,它尽量学习如何测量训练分布与生成分布的距离。这种测量然后被用作损失函数来训练生成器。
我们使用非常简单的模型。生成器变换输入噪音到有8个通道的7x7图像。随后有两个卷积层,先上取样到14x14,最后到28x28。
分判器做的事情大致相同只是反过来。两个卷积屋下取样图像到14x14,然后到7x7。最后一个全链接层产生数字作为输出。上一个教程我们使用sigmoid激活函数,产生0到1之间的数,可以被解释为概率。因为这是一个WGAN,我们使用softplus激活函数。它产生不平衡的正数可被解释为距离。
In [3]:
class DigitGAN(dc.models.WGAN):
def get_noise_input_shape(self):
return (10,)
def get_data_input_shapes(self):
return [(28, 28, 1)]
def create_generator(self):
return tf.keras.Sequential([
Dense(7*7*8, activation=tf.nn.relu),
Reshape((7, 7, 8)),
Conv2DTranspose(filters=16, kernel_size=5, strides=2, activation=tf.nn.relu, padding='same'),
Conv2DTranspose(filters=1, kernel_size=5, strides=2, activation=tf.sigmoid, padding='same')
])
def create_discriminator(self):
return tf.keras.Sequential([
Conv2D(filters=32, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),
Conv2D(filters=64, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),
Dense(1, activation=tf.math.softplus)
])
gan = DigitGAN(learning_rate=ExponentialDecay(0.001, 0.9, 5000))
现在来训练它。就像上一个教程,我们写一个生成器来产生数据。这次数据来自数据集,我们用数据迭代表100次。
另一个不同点并不重要。训练传统的GAN时,重要的是保持生成器和分判器在整个训练过程的平衡。任意一个走得过快,另一个就会很难学习。
WGANs不会有这个问题。事实上,分判器越好,它给出的信号越清晰,它就越容易被生成器学习。因此我们指定generator_steps=0.2以至它仅采取一步训练生成器每五步训练分判器。这趋于产生更快的训练和更好的结果。
In [4]:
def iterbatches(epochs):
for i in range(epochs):
for batch in dataset.iterbatches(batch_size=gan.batch_size):
yield {gan.data_inputs[0]: batch[0]}
gan.fit_gan(iterbatches(100), generator_steps=0.2, checkpoint_interval=5000)
Ending global_step 4999: generator average loss 0.340072, discriminator average loss -0.0234236
Ending global_step 9999: generator average loss 0.52308, discriminator average loss -0.00702729
Ending global_step 14999: generator average loss 0.572661, discriminator average loss -0.00635684
Ending global_step 19999: generator average loss 0.560454, discriminator average loss -0.00534357
Ending global_step 24999: generator average loss 0.556055, discriminator average loss -0.00620613
Ending global_step 29999: generator average loss 0.541958, discriminator average loss -0.00734233
Ending global_step 34999: generator average loss 0.540904, discriminator average loss -0.00736641
Ending global_step 39999: generator average loss 0.524298, discriminator average loss -0.00650514
Ending global_step 44999: generator average loss 0.503931, discriminator average loss -0.00563732
Ending global_step 49999: generator average loss 0.528964, discriminator average loss -0.00590612
Ending global_step 54999: generator average loss 0.510892, discriminator average loss -0.00562366
Ending global_step 59999: generator average loss 0.494756, discriminator average loss -0.00533636
TIMING: model fitting took 4197.860 s
Let's generate some data and see how the results look.
In [5]:
Plot_digits(gan.predict_gan_generator(batch_size=16))
不错,许多生成的图像看起来像手写数字。模型越大训练时间越长结果当然更好。
下载全文请到www.data-vision.net,技术联系电话13712566524