Gan 网络生成图片
本次文章使用 pytorch 框架对 Gan网络生产图片网络架构进行编写与实现, Gan网络架构自己可以自定义, 这里采用自己定义的Gan网络。(完整代码下面链接)
- https://github.com/xiaoaleiBLUE/computer_vision(希望 starred一下)
文章目录
一、Gan网络
- GAN 应用: 数据生成, 图像翻译, 超分辨率(更高清), 图像补全。
- Adversarial: 对抗对手的意思, 两个模型: Generator(生成模型)
- Discriminator(判别模型, 分类模型)
- 创作者(G)目标: 赝品骗过鉴别者 鉴别者(D)目标: 火眼金睛不被骗
二、Gan网络架构
- Gan架构: Generator(生成模型), Discriminator(判别模型, 分类模型), 以下简称 G网络,D网络。
- Gan 网络架构图:
三、Gan网络实现基本思路
四、本文章使用Gan网络架构
4.1 G, D网络架构
- G网络架构(CxHxW通道排列形式) 100x1x1的tensor --> 512x4x4 --> 128x16x16 --> 64x32x32 --> 3x64x64
- D网络架构(CxHxW通道排列形式) 3x64x64的三通道图像或者tensor --> 64x32x32 --> 128x16x16 --> 256x8x8 --> 512x4x4 --> 8192 --> 1
4.2 ConvTranspose2d计算公式
- out = (input-1) x stride - 2 x padding + kernel_size + out_padding
- 比如G网络的 100 x 1 x 1的张量(tensor) --> 512 x 4 x 4
- out = (1-1) x 1 - 2 x 0 + 4 + 0 = 4, 所以 H,W 为 4
4.3 G, D网络架构代码实现
- G 网络, 生成者网络
# G 生成网络, 生成者网络
class G_model(nn.Module):
def __init__(self):
super(G_model, self).__init__()
self.main = nn.Sequential(
# 100*1*1(张量) ----> 4*4*512
# out = (input-1)*stride - 2*padding + kernel_size + out_padding
# out = (1-1)*1 - 2*0 + 4 + 0 = 4
nn.ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),
# BN
nn.BatchNorm2d(512),
# Relu
nn.ReLU(inplace=True),
# 4*4*512 ----> 8*8*256
# out = (4-1)*2 - 2*1 + 4 + 0 = 8
nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
# BN
nn.BatchNorm2d(256),
# Relu
nn.ReLU(inplace=True),
# 8*8*256 ----> 16*16*128
nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
# BN
nn.BatchNorm2d(128),
# Relu
nn.ReLU(inplace=True),
# 16*16*128 ----> 32*32*64
nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
# BN
nn.BatchNorm2d(64),
# Relu
nn.ReLU(inplace=True),
# 32*32*64 ----> 64*64*3
nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
# tanh()
nn.Tanh()
)
def forward(self, x):
return self.main(x)
- D 网络, 判别者网络
class D_model(nn.Module):
def __init__(self):
super(D_model, self).__init__()
self.main = nn.Sequential(
# 64*64*3 ----> 32*32*64
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=True),
# Relu
nn.LeakyReLU(0.2, inplace=True),
# 32*32*64 ----> 16*16*128
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True),
# BN
nn.BatchNorm2d(128),
# Relu
nn.LeakyReLU(0.2, inplace=True),
# 16*16*128 ----> 8*8*256
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=True),
# BN
nn.BatchNorm2d(256),
# Relu
nn.LeakyReLU(0.2, inplace=True),
# 8*8*256 ----> 4*4*512=8192
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=True),
# BN
nn.BatchNorm2d(512),
# Relu
nn.LeakyReLU(0.2, inplace=True),
# flatten
nn.Flatten(),
# 全连接
nn.Linear(8192, 1),
# sigmoid
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
4.4 数据集相关操作
- 文件夹读取图片 --> dataset --> dataloader
# 导入数据
img_size = 64
img_preprocess = transforms.Compose([
# 缩放
transforms.Resize(img_size),
# 中心裁剪 64*64 的正方形
transforms.CenterCrop(img_size),
# PIL图像转为tensor H*W*C ---> C*H*W
transforms.ToTensor(),
# 归一化到[-1,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 从文件夹读取图片
dataset = tdst.ImageFolder(root='./data/', transform=img_preprocess)
# 加载成 dataloader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
# 查看批次数量 12324 / 128 = 96.28, 向上取整
print(len(dataloader))
- 显示dataloader 中的图片
# 显示图片, 其中 x[0]:对应的图片,有shape, x[1]:对应类别标签, x[0].shape:torch.Size([128, 3, 64, 64]) 128对应批次大小
for x in dataloader:
# 设置画布大小
fig = plt.figure(figsize=(8, 8))
for i in range(16):
plt.subplot(4, 4, i+1)
# 转为 numpy, x[0][i].shape: torch.Size([3, 64, 64])
img = x[0][i].numpy()
# 通道顺序变换, 调整通道顺序为 PIL 格式
img = np.transpose(img, (1, 2, 0))
# 先转到[0,1],再乘以255
img = (img + 1) / 2 * 255
# 取整
img = img.astype('int')
# 显示
plt.imshow(img)
plt.axis('off')
plt.show()
break
4.5 优化器和损失函数
- 对G网络, D网络分别定义损失函数和优化器
# loss 损失函数, 二分类的损失函数
loss_fn = nn.BCELoss()
# 优化器
D_optimizer = optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_optimizer = optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
五、训练
- 按照上述Gan基本思路进行训练
# 开始训练
Epoch_num = 100
for epoch in range(Epoch_num):
# 获取批次图像
start_time = time.time()
for i, data in enumerate(dataloader):
# 训练判别网络D:真实数据标记为1
# 每次update前清空梯度
D_model.zero_grad()
# 获取数据, data[0]是图片, data[1]是类别
imgs_batch = data[0].to(device)
# 动态获取图片的batch_size
b_size = imgs_batch.size(0)
# 计算输出
output = D_model(imgs_batch).view(-1)
# 构建全1向量 label
ones_label = torch.full((b_size, ), 1, dtype=torch.float, device=device)
# 计算 loss
d_loss_real = loss_fn(output, ones_label)
# 反向传播
d_loss_real.backward()
# 梯度更新
D_optimizer.step()
# 训练判别网络D:假数据标记为0
# 清楚梯度
D_model.zero_grad()
# 构建随机张量
noise_tensor = torch.randn(b_size, 100, 1, 1, device=device)
# 生成假的图片
generated_imgs = G_model(noise_tensor)
# 假图片的输出, 此时不需要训练 G,
output = D_model(generated_imgs.detach()).view(-1)
# 构建全 0 向量
zeros_label = torch.full((b_size, ), 0, dtype=torch.float, device=device)
# 计算 loss
d_loss_fake = loss_fn(output, zeros_label)
# 反向传播
d_loss_fake.backward()
# 梯度更新
D_optimizer.step()
# 训练生成网络G:假数据标记为1
# 清楚梯度
G_model.zero_grad()
# 随机张量
noise_tensor = torch.randn(b_size, 100, 1, 1, device=device)
# 生成假的图片
generated_imgs = G_model(noise_tensor)
# 假图片的输出
output = D_model(generated_imgs).view(-1)
# 构建全1向量
ones_label = torch.full((b_size, ), 1, dtype=torch.float, device=device)
# 计算 loss
g_loss = loss_fn(output, ones_label)
# 反向传播
g_loss.backward()
# 生成网络梯度更新
G_optimizer.step()
# 打印训练时间
print('第{}个epoch所用时间: {}s'.format(epoch, time.time() - start_time))
# 每一个 epoch 输出结果
# 用 no_grad 表示梯度不跟踪
with torch.no_grad():
# 生成 16 个随机向量
fixed_noise = torch.randn(16, 100, 1, 1, device=device)
# 生成图片
fake_imgs = G_model(fixed_noise).detach().cpu().numpy()
# 画布大小
fig = plt.figure(figsize=(10, 10))
for i in range(fake_imgs.shape[0]):
plt.subplot(4, 4, i+1)
img = np.transpose(fake_imgs[i], (1, 2, 0))
img = (img + 1) / 2 * 255
img = img.astype('int')
plt.imshow(img)
plt.axis('off')
plt.show()