使用 tf.keras 和 eager 实现 Pix2Pix
本教程使用 这篇文章 中提出的条件 GAN 演示图像到图像转换,使用这种技术,我们可以将黑白照片着色,将谷歌地图转换为谷歌地球等等。本教程中,我们将建筑立面转换为真实的建筑。
我们将使用 CMP Facade 数据集,为了让教程尽量简短,我们将使用数据集的一个 预处理拷贝,该拷贝由论文作者创建。
在一个 P100 GPU 上,每个周期花费大约 58 秒。
下面是模型训练 200 个周期后生成的输出。
导入 TensorFlow 并启用 eager execution
# 导入 TensorFlow >= 1.10 并启用 eager execution
import tensorflow as tf
tf.enable_eager_execution()
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import PIL
from IPython.display import clear_output
载入数据集
你可以下载原始的或者预处理过的数据集。如论文中提到的,我们对训练集进行随机抖动和镜像处理。
- 随机抖动:将图像调整至 286 x 286,然后随机裁剪至 256 x 256
- 随机镜像:将图像随机水平翻转
# Download the preprocessed facades archive into the current working
# directory and extract it there (extract=True).
path_to_zip = tf.keras.utils.get_file(
    'facades.tar.gz',
    cache_subdir=os.path.abspath('.'),
    origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz',
    extract=True)

# Root directory of the extracted dataset; the input pipeline globs
# PATH + 'train/*.jpg' and PATH + 'test/*.jpg'.
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
Downloading data from https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 1s 0us/step
# Input pipeline settings.
BUFFER_SIZE = 400  # shuffle buffer size for the training dataset
BATCH_SIZE = 1     # intended number of image pairs per batch
# Width and height of the images returned by load_image.
IMG_WIDTH = IMG_HEIGHT = 256
def load_image(image_file, is_train):
    """Load one side-by-side JPEG and split it into an (input, real) pair.

    The left half of the decoded file is used as the real (target) image and
    the right half as the generator input. Both halves are cast to float32,
    brought to IMG_HEIGHT x IMG_WIDTH and rescaled with `x / 127.5 - 1`.

    Args:
        image_file: Path of the JPEG file to read (a string tensor when this
            function is used through `tf.data.Dataset.map`).
        is_train: Python bool. When True, random jitter (resize to 286 x 286,
            then a random IMG_HEIGHT x IMG_WIDTH crop) and random horizontal
            mirroring are applied. When False, the halves are only resized.

    Returns:
        A tuple `(input_image, real_image)` of float32 image tensors.
    """
    image = tf.read_file(image_file)
    image = tf.image.decode_jpeg(image)

    # The file holds two pictures side by side; split at half the width.
    w = tf.shape(image)[1] // 2
    real_image = tf.cast(image[:, :w, :], tf.float32)
    input_image = tf.cast(image[:, w:, :], tf.float32)

    if is_train:
        # Random jitter, step 1: resize both images to 286 x 286 x 3.
        input_image = tf.image.resize_images(
            input_image, [286, 286],
            align_corners=True,
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        real_image = tf.image.resize_images(
            real_image, [286, 286],
            align_corners=True,
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # Random jitter, step 2: crop to IMG_HEIGHT x IMG_WIDTH x 3. The two
        # images are stacked first so both receive the same crop window.
        stacked_image = tf.stack([input_image, real_image], axis=0)
        cropped_image = tf.random_crop(
            stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
        cropped_input, cropped_real = cropped_image[0], cropped_image[1]

        # Random mirroring. The coin flip has to be a TensorFlow op: a
        # Python-side random call would be evaluated only once, when
        # Dataset.map traces this function, so every image would get the
        # same decision. tf.cond re-evaluates the draw for each element.
        input_image, real_image = tf.cond(
            tf.random_uniform([]) > 0.5,
            lambda: (tf.image.flip_left_right(cropped_input),
                     tf.image.flip_left_right(cropped_real)),
            lambda: (cropped_input, cropped_real))
    else:
        # No augmentation at test time; just make sure the size is right.
        # ResizeMethod.BICUBIC is the named form of the integer 2 used by
        # the TF 1.x resize_images API.
        input_image = tf.image.resize_images(
            input_image, size=[IMG_HEIGHT, IMG_WIDTH],
            align_corners=True, method=tf.image.ResizeMethod.BICUBIC)
        real_image = tf.image.resize_images(
            real_image, size=[IMG_HEIGHT, IMG_WIDTH],
            align_corners=True, method=tf.image.ResizeMethod.BICUBIC)

    # Normalize the images to [-1, 1].
    input_image = (input_image / 127.5) - 1
    real_image = (real_image / 127.5) - 1

    return input_image, real_image
使用 tf.data 创建批量,map(预处理)和打乱数据集
# Training pipeline: list the files, shuffle them, decode and augment each
# one, then batch. BATCH_SIZE replaces the previously hard-coded 1 so the
# declared constant actually controls the batch size.
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(lambda x: load_image(x, True))
train_dataset = train_dataset.batch(BATCH_SIZE)

# Test pipeline: same loading without shuffling or augmentation.
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
test_dataset = test_dataset.map(lambda x: load_image(x, False))
test_dataset = test_dataset.batch(BATCH_SIZE)
建立生成器和判别器模型
- 生成器
- 生成器的架构是改进的 U-Net。
- 编码器中的块是(Conv -> Batchnorm -> Leaky ReLU)
- 解码器中的块是(Transposed Conv(转置卷积) -> Batchnorm -> Dropout (应用于前三个块) -> ReLU)
- 编码器和解码器间有跳跃连接(skip connection,和 U-Net 相同)
- 判别器
- 判别器是 PatchGAN。
- 判别器中的块是(Conv -> Batchnorm -> Leaky ReLU)
- 最后一层后的输出形状是(batch_size, 30, 30, 1)
- 30x30 输出中的每个元素对输入图像的一个 70x70 区域进行分类(这种架构称为 PatchGAN)。
- 判别器接受两个输入。
- 输入图像和目标图像,目标图像应被分类为真。
- 输入图像和生成图像(生成器的输出),生成图像应被分类为假。
- 在代码中将这两个输入拼接在一起(`tf.concat([inp, tar], axis=-1)`)
- 输入在通过生成器和判别器时的形状变化,见代码中的注释。
# Number of filters of the generator's final transposed convolution, i.e. the
# channel count of the generated image.
OUTPUT_CHANNELS = 3
class Downsample(tf.keras.Model):
    """Generator encoder block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Side length of the square convolution kernel.
        apply_batchnorm: Whether a BatchNormalization layer follows the
            convolution.
    """

    def __init__(self, filters, size, apply_batchnorm=True):
        super(Downsample, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        kernel_init = tf.random_normal_initializer(0., 0.02)

        self.conv1 = tf.keras.layers.Conv2D(
            filters, (size, size),
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False)
        if self.apply_batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()

    def call(self, x, training):
        """Apply the block to `x`; `training` is forwarded to BatchNorm."""
        out = self.conv1(x)
        if self.apply_batchnorm:
            out = self.batchnorm(out, training=training)
        return tf.nn.leaky_relu(out)
class Upsample(tf.keras.Model):
    """Generator decoder block.

    Stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU,
    followed by concatenation with a skip tensor along the last axis.

    Args:
        filters: Number of transposed-convolution filters.
        size: Side length of the square kernel.
        apply_dropout: Whether a Dropout layer is applied after BatchNorm.
    """

    def __init__(self, filters, size, apply_dropout=False):
        super(Upsample, self).__init__()
        self.apply_dropout = apply_dropout
        kernel_init = tf.random_normal_initializer(0., 0.02)

        self.up_conv = tf.keras.layers.Conv2DTranspose(
            filters, (size, size),
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False)
        self.batchnorm = tf.keras.layers.BatchNormalization()
        if self.apply_dropout:
            self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, x1, x2, training):
        """Upsample `x1` and concatenate the result with the skip tensor `x2`."""
        upsampled = self.batchnorm(self.up_conv(x1), training=training)
        if self.apply_dropout:
            upsampled = self.dropout(upsampled, training=training)
        activated = tf.nn.relu(upsampled)
        return tf.concat([activated, x2], axis=-1)
class Generator(tf.keras.Model):
    """U-Net style generator.

    Eight Downsample blocks form the encoder and seven Upsample blocks the
    decoder; every decoder block also receives the output of the matching
    encoder block as a skip connection. A final stride-2 transposed
    convolution followed by tanh produces the OUTPUT_CHANNELS-channel image.
    """

    def __init__(self):
        super(Generator, self).__init__()
        last_init = tf.random_normal_initializer(0., 0.02)

        # Encoder; the first block has no batch normalization.
        self.down1 = Downsample(64, 4, apply_batchnorm=False)
        self.down2 = Downsample(128, 4)
        self.down3 = Downsample(256, 4)
        self.down4 = Downsample(512, 4)
        self.down5 = Downsample(512, 4)
        self.down6 = Downsample(512, 4)
        self.down7 = Downsample(512, 4)
        self.down8 = Downsample(512, 4)

        # Decoder; the first three blocks apply dropout.
        self.up1 = Upsample(512, 4, apply_dropout=True)
        self.up2 = Upsample(512, 4, apply_dropout=True)
        self.up3 = Upsample(512, 4, apply_dropout=True)
        self.up4 = Upsample(512, 4)
        self.up5 = Upsample(256, 4)
        self.up6 = Upsample(128, 4)
        self.up7 = Upsample(64, 4)

        self.last = tf.keras.layers.Conv2DTranspose(
            OUTPUT_CHANNELS, (4, 4),
            strides=2,
            padding='same',
            kernel_initializer=last_init)

    @tf.contrib.eager.defun
    def call(self, x, training):
        """Translate the input batch `x` into a generated image batch."""
        # x shape == (bs, 256, 256, 3)
        encoder = (self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7, self.down8)
        decoder = (self.up1, self.up2, self.up3, self.up4,
                   self.up5, self.up6, self.up7)

        # Encoder pass, keeping every block output for the skip connections:
        # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
        # -> (bs, 16, 16, 512) -> (bs, 8, 8, 512) -> (bs, 4, 4, 512)
        # -> (bs, 2, 2, 512) -> (bs, 1, 1, 512)
        encoded = []
        features = x
        for down in encoder:
            features = down(features, training=training)
            encoded.append(features)

        # Decoder pass. up1 is paired with down7's output, up2 with down6's,
        # ..., up7 with down1's; the bottleneck itself is not a skip input.
        # (bs, 2, 2, 1024) -> (bs, 4, 4, 1024) -> (bs, 8, 8, 1024)
        # -> (bs, 16, 16, 1024) -> (bs, 32, 32, 512) -> (bs, 64, 64, 256)
        # -> (bs, 128, 128, 128)
        for up, skip in zip(decoder, reversed(encoded[:-1])):
            features = up(features, skip, training=training)

        # (bs, 256, 256, 3)
        return tf.nn.tanh(self.last(features))
class DiscDownsample(tf.keras.Model):
    """Discriminator block: stride-2 Conv2D -> optional BatchNorm -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Side length of the square convolution kernel.
        apply_batchnorm: Whether a BatchNormalization layer follows the
            convolution.
    """

    def __init__(self, filters, size, apply_batchnorm=True):
        super(DiscDownsample, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        kernel_init = tf.random_normal_initializer(0., 0.02)

        self.conv1 = tf.keras.layers.Conv2D(
            filters, (size, size),
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False)
        if self.apply_batchnorm:
            self.batchnorm = tf.keras.layers.BatchNormalization()

    def call(self, x, training):
        """Apply the block to `x`; `training` is forwarded to BatchNorm."""
        out = self.conv1(x)
        if self.apply_batchnorm:
            out = self.batchnorm(out, training=training)
        return tf.nn.leaky_relu(out)
class Discriminator(tf.keras.Model):
    """PatchGAN discriminator.

    Takes an input image and a candidate target image, concatenates them
    along the channel axis and returns a (bs, 30, 30, 1) map of raw logits.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        kernel_init = tf.random_normal_initializer(0., 0.02)

        self.down1 = DiscDownsample(64, 4, False)
        self.down2 = DiscDownsample(128, 4)
        self.down3 = DiscDownsample(256, 4)

        # Zero padding is used here because the shape has to go from
        # (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512).
        self.zero_pad1 = tf.keras.layers.ZeroPadding2D()
        self.conv = tf.keras.layers.Conv2D(
            512, (4, 4),
            strides=1,
            kernel_initializer=kernel_init,
            use_bias=False)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        # Shape goes from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1).
        self.zero_pad2 = tf.keras.layers.ZeroPadding2D()
        self.last = tf.keras.layers.Conv2D(
            1, (4, 4),
            strides=1,
            kernel_initializer=kernel_init)

    @tf.contrib.eager.defun
    def call(self, inp, tar, training):
        """Return patch logits for the pair (`inp`, `tar`)."""
        # Concatenate the input and the target: (bs, 256, 256, channels*2)
        features = tf.concat([inp, tar], axis=-1)

        # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
        for down in (self.down1, self.down2, self.down3):
            features = down(features, training=training)

        padded = self.zero_pad1(features)  # (bs, 34, 34, 256)
        features = self.conv(padded)       # (bs, 31, 31, 512)
        features = self.batchnorm1(features, training=training)
        features = tf.nn.leaky_relu(features)

        padded = self.zero_pad2(features)  # (bs, 33, 33, 512)
        # No sigmoid activation is added here because the loss function
        # expects raw logits.
        return self.last(padded)           # (bs, 30, 30, 1)
# The call methods of the generator and the discriminator are decorated with
# tf.contrib.eager.defun().
# Using defun gives a performance speedup (about 25 sec / epoch).
generator = Generator()
discriminator = Discriminator()
定义损失函数和优化器
- 判别器损失
- 判别器损失接受两个输入:真实图像,生成图像。
- real_loss 是真实图像与全 1 数组(因为这些是真实图像)的 sigmoid 交叉熵损失。
- generated_loss 是生成图像与全 0 数组(因为这些是假图像)的 sigmoid 交叉熵损失。
- total_loss(代码中为 total_disc_loss)是 real_loss 与 generated_loss 之和。
- 生成器损失
- 生成图像和数组 1 的 sigmoid 交叉熵损失。
- 论文还包括了生成图像与目标图像之间的 L1 损失,即平均绝对误差。
- 这使得生成图像在结构上与目标图像相似。
- 计算总生成器损失的公式为:total_gen_loss = gan_loss + LAMBDA * l1_loss,LAMBDA = 100,该值由论文作者决定。
# Weight of the L1 term in the total generator loss.
LAMBDA = 100
def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator loss for one batch.

    Args:
        disc_real_output: Discriminator logits for (input, target) pairs;
            scored against labels of all ones.
        disc_generated_output: Discriminator logits for (input, generated)
            pairs; scored against labels of all zeros.

    Returns:
        The sum of the two sigmoid cross-entropy losses.
    """
    sigmoid_ce = tf.losses.sigmoid_cross_entropy

    loss_on_real = sigmoid_ce(
        multi_class_labels=tf.ones_like(disc_real_output),
        logits=disc_real_output)
    loss_on_generated = sigmoid_ce(
        multi_class_labels=tf.zeros_like(disc_generated_output),
        logits=disc_generated_output)

    return loss_on_real + loss_on_generated
def generator_loss(disc_generated_output, gen_output, target):
    """Return the generator loss for one batch.

    Args:
        disc_generated_output: Discriminator logits for (input, generated)
            pairs; scored against labels of all ones.
        gen_output: Images produced by the generator.
        target: Ground-truth images that `gen_output` is compared with.

    Returns:
        The adversarial sigmoid cross-entropy loss plus LAMBDA times the
        mean absolute error between `target` and `gen_output`.
    """
    adversarial = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(disc_generated_output),
        logits=disc_generated_output)

    # Mean absolute error.
    mean_abs_error = tf.reduce_mean(tf.abs(target - gen_output))

    return adversarial + (LAMBDA * mean_abs_error)
# One Adam optimizer per network, both configured identically.
generator_optimizer = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5)
检查点(基于对象保存)
# Checkpoint files are written under checkpoint_dir with the "ckpt" prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track both optimizers and both models in a single checkpoint object.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)
训练
- 从迭代数据集开始
- 生成器接受输入图像,并生成输出
- 判别器接受输入图像和生成图像作为第一个输入,第二个输入是输入图像和目标图像
- 下一步,计算生成器和判别器的损失
- 然后,计算损失关于生成器和判别器变量(输入)的梯度,并应用于优化器
生成图像
- 训练完成后,开始生成一些图像
- 将测试数据集中的图像输入生成器
- 生成器将输入图像转换为我们期望的输出
- 最后一步是绘制预测图像
# Number of epochs passed to train() below.
EPOCHS = 200
def generate_images(model, test_input, tar):
    """Plot the input, the ground truth and the model's prediction.

    Only the first example of each batch is shown, in a single row of three
    panels.

    Args:
        model: Callable invoked as `model(test_input, training=True)`.
        test_input: Batch of input images.
        tar: Batch of ground-truth images.
    """
    # training=True is intentional: when running the model on the test
    # dataset we want the batch statistics. With training=False we would
    # get the statistics accumulated from the training dataset instead.
    prediction = model(test_input, training=True)

    panels = [
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', prediction[0]),
    ]

    plt.figure(figsize=(15, 15))
    for position, (caption, image) in enumerate(panels, 1):
        plt.subplot(1, 3, position)
        plt.title(caption)
        # Bring the pixel values into [0, 1] for plotting.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()
def train(dataset, epochs):
    """Train the module-level generator and discriminator.

    For every batch, both networks are run under their own gradient tape,
    the two losses are computed, and each optimizer applies the gradients of
    its network's loss. After every epoch one test example is plotted, and
    every 20 epochs a checkpoint is saved.

    Relies on the module-level `generator`, `discriminator`,
    `generator_optimizer`, `discriminator_optimizer`, `test_dataset`,
    `checkpoint` and `checkpoint_prefix`.

    Args:
        dataset: Iterable yielding `(input_image, target)` batches.
        epochs: Number of passes over `dataset`.
    """
    for epoch in range(epochs):
        start = time.time()

        for input_image, target in dataset:
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                gen_output = generator(input_image, training=True)

                disc_real_output = discriminator(
                    input_image, target, training=True)
                disc_generated_output = discriminator(
                    input_image, gen_output, training=True)

                gen_loss = generator_loss(
                    disc_generated_output, gen_output, target)
                disc_loss = discriminator_loss(
                    disc_real_output, disc_generated_output)

            # Differentiate with respect to the trainable weights only.
            # `.variables` also lists non-trainable state (such as the
            # BatchNormalization moving statistics), for which no gradient
            # exists and the optimizer would be handed None.
            generator_variables = generator.trainable_variables
            discriminator_variables = discriminator.trainable_variables

            generator_gradients = gen_tape.gradient(
                gen_loss, generator_variables)
            discriminator_gradients = disc_tape.gradient(
                disc_loss, discriminator_variables)

            generator_optimizer.apply_gradients(
                zip(generator_gradients, generator_variables))
            discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, discriminator_variables))

        # Show the progress on one test example after every epoch.
        clear_output(wait=True)
        for inp, tar in test_dataset.take(1):
            generate_images(generator, inp, tar)

        # Save (checkpoint) the model every 20 epochs.
        if (epoch + 1) % 20 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                            time.time()-start))
# Run the training loop on the training dataset for EPOCHS epochs.
train(train_dataset, EPOCHS)
Time taken for epoch 1 is 129.104407787323 sec
(这个结果是我在 Colab 上跑的,训练速度较慢,因此只训练了 6 个周期)
恢复最近的检查点
# Restore the objects tracked by `checkpoint` from the latest checkpoint
# found in checkpoint_dir.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
在完整的测试集上进行测试
# Plot input, ground truth and prediction for every batch of the test set.
for test_input, test_target in test_dataset:
    generate_images(generator, test_input, test_target)
(选了其中的三幅)