Pytorch从0到1之生成对抗神经网络GAN——(9)

开篇

在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。

GAN生成对抗网络

库的引入

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~

设备的配置以及超参数的定义

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~

如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)~

图像的处理和转换

transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], 
                         std=[0.5])])~

我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载

mnist = torchvision.datasets.MNIST(root = '../../data/',
                                   train = True,
                                   transforms = transforms,
                                   download = True)

、
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,
                                          shuffle = True)~

GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义

# Discrimator
D = nn.Sequential(
    nn.Linear(image_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size,1),
    nn.Sigmoid())

# Generator
G = nn.Sequential(
    nn.Linear(latent_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size,image_size),
    nn.Tanh())

D = D.to(device)
G = G.to(device)~

D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~

这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义

def denorm(x):
    out = (x + 1) / 2
    # 将out限制在0-1
    return out.clamp(0,1)
# 重置梯度
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

训练模型

total_step = len(data_loader)
for epoch in range(num_epochs):
    for i,(images,_) in enumerate(data_loader):
        images = images.reshape(batch_size,-1).to(device)
        # 为计算损失函数生成标签,真是标签是1,虚假标签是0
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size,1).to(device)

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # 这里的损失函数应该都是0,因为我们用真实图片去测试,损失函数一定是0
        # 我们的目的是将对的分到real中,错的分到fake中,所以要求2个损失
        outputs = D(images)
        d_loss_real = criterion(outputs,real_labels)
        real_score = outputs
        # Compute BCELoss using fake images
        # 这里的损失函数应该是1,因为我们用的是虚假图片,且为随机生成的码
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs,fake_labels)
        fake_score = outputs

        # 反向传播优化
        d_loss = d_loss_fake + d_loss_real
        # 清空梯度
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
	
	# 训练生成器
       # 用虚假图片计算损失
        z = torch.randn(batch_size,latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        g_loss = criterion(outputs,real_labels)

        # 反向传播优化
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # 保存真实图像
    if (epoch + 1) == 1:
        images = images.reshape(images.size(0),1,28,28)
        save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))

    # 保存样本图像
    fake_images = fake_images.reshape(fake_images.size(0),1,28,28)
    save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))

保存模型

torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')

总结

GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。

原创文章 101 获赞 13 访问量 2316

猜你喜欢

转载自blog.csdn.net/weixin_44755413/article/details/105878245