「这是我参与11月更文挑战的第13天,活动详情查看:2021最后一次更文挑战」
GAN直观理解
Ian Goodfellow 在首次提出GAN,使用了形象的比喻来介绍 GAN 模型:生成网络 G 的功能就是产生逼真的假钞试图欺骗鉴别器 D,鉴别器 D 通过学习真钞和生成器 G 生成的假钞来掌握钞票的鉴别方法。这两个网络在相互博弈中进行训练,直到生成器 G 产生的假钞使鉴别器 D 难以分辨。而DCGAN是使用卷积操作和反卷积操作来替代原始GAN中的全连接操作。
DCGAN网络结构
GAN 包含生成网络(Generator, G
)和判别网络(Discriminator, D
),其中 G
用于学习数据的真实分布, D
用于将 G
生成的数据与真实样本区分开。
生成网络
G
从先验分布
中采样潜变量
,通过 G 学习分布
,获得生成样本
。其中潜变量z的先验分布
可以假设为常见的分布。
判别网络
D
是一个二分类网络,它判断采样自真实数据分布
的数据
和采样自生成网络的生成的数据
,判别网络的训练数据集由
和
组成。真实样本
的标签标为1,生成网络产生的样本
标为0,通过最小化判别网络 D 的预测值与标签之间的误差来优化判别网络。
GAN训练目标
判别网络目标是分辨出真样本 与假样本 。它的目标是最小化预测值和真实值之间的交叉熵损失函数:
CE表示交叉熵损失函数CrossEntropy:
判别网络 D 的优化目标是:
把 转换为 :
对于生成网络 ,希望生成数据能够骗过判别网络 D,假样本 在判别网络的输出越接近真实的标签越好。即在训练生成网络时,希望判别网络的输出 越逼近 1 越好,最小化 与 1 之间的交叉熵损失函数:
把 转换为 :
其中 为生成网络 G 的参数。
在训练过程中迭代训练鉴别器和生成器。
DCGAN实现
使用cifar10
的训练集作为GAN训练集实现DCGAN。
数据加载
加载cifar10
的训练集,并对数据进行预处理
#批大小
batch_size = 64
(train_x,_),_ = keras.datasets.cifar10.load_data()
#数据归一化
train_x = train_x / (255. / 2) - 1
print(train_x.shape)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
复制代码
网络
网络由鉴别网络与生成网络构成
鉴别网络
class Discriminator(keras.Model):
def __init__(self):
super(Discriminator,self).__init__()
filters = 64
self.conv1 = keras.layers.Conv2D(filters,4,2,'valid',use_bias=False)
self.bn1 = keras.layers.BatchNormalization()
self.conv2 = keras.layers.Conv2D(filters*2,4,2,'valid',use_bias=False)
self.bn2 = keras.layers.BatchNormalization()
self.conv3 = keras.layers.Conv2D(filters*4,3,1,'valid',use_bias=False)
self.bn3 = keras.layers.BatchNormalization()
self.conv4 = keras.layers.Conv2D(filters*8,3,1,'valid',use_bias=False)
self.bn4 = keras.layers.BatchNormalization()
#全局池化
self.pool = keras.layers.GlobalAveragePooling2D()
self.flatten = keras.layers.Flatten()
self.fc = keras.layers.Dense(1)
def call(self,inputs,training=True):
x = inputs
x = tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
x = tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
x = tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
x = tf.nn.leaky_relu(self.bn4(self.conv4(x),training=training))
x = self.pool(x)
x = self.flatten(x)
logits = self.fc(x)
return logits
复制代码
生成网络
class Generator(keras.Model):
def __init__(self):
super(Generator,self).__init__()
filters = 64
self.conv1 = keras.layers.Conv2DTranspose(filters*4,4,1,'valid',use_bias=False)
self.bn1 = keras.layers.BatchNormalization()
self.conv2 = keras.layers.Conv2DTranspose(filters*3,4,2,'same',use_bias=False)
self.bn2 = keras.layers.BatchNormalization()
self.conv3 = keras.layers.Conv2DTranspose(filters*1,4,2,'same',use_bias=False)
self.bn3 = keras.layers.BatchNormalization()
self.conv4 = keras.layers.Conv2DTranspose(3,4,2,'same',use_bias=False)
def call(self,inputs,training=False):
x = inputs
x = tf.reshape(x,(x.shape[0],1,1,x.shape[1]))
x = tf.nn.relu(x)
x = tf.nn.relu(self.bn1(self.conv1(x),training=training))
x = tf.nn.relu(self.bn2(self.conv2(x),training=training))
x = tf.nn.relu(self.bn3(self.conv3(x),training=training))
x = self.conv4(x)
x = tf.tanh(x)
return x
复制代码
网络训练
训练时可以训练鉴别器多次然后训练一次生成器
定义损失函数
def celoss_ones(logits):
# 计算属于与标签为1的交叉熵
y = tf.ones_like(logits)
loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
return tf.reduce_mean(loss)
def celoss_zeros(logits):
# 计算属于与标签为0的交叉熵
y = tf.zeros_like(logits)
loss = keras.losses.binary_crossentropy(y, logits, from_logits=True)
return tf.reduce_mean(loss)
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
# 计算鉴别器的损失函数
# 采样生成图片
fake_image = generator(batch_z, is_training)
# 判定生成图片
d_fake_logits = discriminator(fake_image, is_training)
# 判定真实图片
d_real_logits = discriminator(batch_x, is_training)
# 真实图片与1之间的误差
d_loss_real = celoss_ones(d_real_logits)
# 生成图片与0之间的误差
d_loss_fake = celoss_zeros(d_fake_logits)
# 合并误差
loss = d_loss_fake + d_loss_real
return loss
def g_loss_fn(generator, discriminator, batch_z, is_training):
#计算生成器的损失函数
# 采样生成图片
fake_image = generator(batch_z, is_training)
# 在训练生成网络时,需要迫使生成图片判定为真
d_fake_logits = discriminator(fake_image, is_training)
# 计算生成图片与1之间的误差
loss = celoss_ones(d_fake_logits)
return loss
复制代码
实例化网络及优化器
#定义超参数
#潜变量维度
z_dim = 100
#epoch大小
epochs = 300
#批大小
batch_size = 64
#学习率
lr = 0.0002
is_training = True
#实例化网络
discriminator = Discriminator()
discriminator.build(input_shape=(4,32,32,3))
discriminator.summary()
generator = Generator()
generator.build(input_shape=(4,z_dim))
generator.summary()
#实例化优化器
g_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
d_optimizer = keras.optimizers.Adam(learning_rate=lr,beta_1=0.5)
复制代码
训练
#统计损失值
d_losses = []
g_losses = []
for epoch in range(epochs):
for _,batch_x in enumerate(dataset):
batch_z = tf.random.normal([batch_size,z_dim])
with tf.GradientTape() as tape:
d_loss = d_loss_fn(generator,discriminator,batch_z,batch_x,is_training)
grads = tape.gradient(d_loss,discriminator.trainable_variables)
d_optimizer.apply_gradients(zip(grads,discriminator.trainable_variables))
with tf.GradientTape() as tape:
g_loss = g_loss_fn(generator,discriminator,batch_z,is_training)
grads = tape.gradient(g_loss,generator.trainable_variables)
g_optimizer.apply_gradients(zip(grads,generator.trainable_variables))
复制代码
效果展示
训练测试,可以通过调整超参数来获得更好的效果。
训练 26 个 epoch 的效果: