生成对抗网络DCGAN

1.介绍

论文:Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

论文地址:https://arxiv.org/abs/1511.06434

DCGAN将CNN和原始的GAN结合到一起,生成模型和判别模型都运用了深度卷积神经网络的生成对抗网络,奠定之后几乎所有GAN的基本网络架构,极大地提升了原始GAN训练的稳定性以及生成结果质量。 

2.改进点

  • DCGAN的生成器和判别器都舍弃了CNN的池化层,判别器保留CNN的整体架构,生成器将卷积层替换成了反卷积层(ConvTranspose2d) 
  • 在判别器和生成器中使用了Batch Normalization 层,这有助于处理初始化不好导致的训练问题,加速模型训练,提升了训练的稳定性。 注意在生成器的输出层和判别器的输入层不使用BN层。 
  • 在生成器中除输出层使用Tanh()激活函数,其余层全部使用ReLU激活函数。 在判别器中,除输出层外所有层都使用LeakyReLU激活函数, 防止梯度稀疏。
  • 在生成器中除输出层使用Tanh()激活函数,其余层全部使用ReLU激活函数。 

3.结构图

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, utils
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
import tqdm

ROOT_TRAIN = r'D:\CNN\anime-faces'

train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)  # 加载训练集
dataloader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=256,
                                           shuffle=True,
                                           num_workers=0)


# 定义生成器,输入是长度为100的噪声(正态分布随机数)
# 输出为3*224*224的图片(tensor)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.linear1 = nn.Linear(100, 256*16*16)
        self.bn1 = nn.BatchNorm1d(256*16*16)
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          stride=1,
                                          padding=1)  #128*56*56
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 64*112*112
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 3,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)  # 3*224*224

    def forward(self, x): #x为噪声输入
        x = F.relu(self.linear1(x)) #100 -- 256*56*56
        x = self.bn1(x)
        x = x.view(-1, 256, 16, 16)
        x = F.relu(self.deconv1(x)) #256*56*56 -- 128*56*56
        x = self.bn2(x)
        x = F.relu(self.deconv2(x)) #128*56*56 -- 64*112*112
        x = self.bn3(x)
        x = torch.tanh(self.deconv3(x)) #64*112*112 -- 3*224*224 生成器的输出不使用bn层
        return x


# 定义判别器,输入为3*224*224的图片,输出为二分类概率值
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*15*15, 1)

    def forward(self, x):
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.3)  #64*111*111 判别器的输入不使用bn层
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.3)  #128*55*55
        x = self.bn(x)
        x = x.view(-1, 128*15*15) #展平
        x = torch.sigmoid(self.fc(x))
        return x


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

# 判别器优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4) #通过减小判别器的学习率降低其能力
# 生成器优化器
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-3)

loss_fn = torch.nn.BCELoss() # 二元交叉熵损失

# 绘图函数,将每一个epoch中生成器生成的图片绘制
def gen_img_plot(model, epoch, test_input): # model为Generator/Discriminator,test_input代表生成器输入的随机数
    # prediction = np.squeeze(model(test_input).detach().cpu().numpy()) #squeeze为去掉通道维度
    prediction = model(test_input).permute(0, 2, 3, 1).cpu().numpy() #将通道维度放在最后
    plt.figure(figsize=(10, 10))
    for i in range(prediction.shape[0]): #prediction.shape[0]=test_input的batchsize
        plt.subplot(2, 2, i + 1)
        plt.imshow((prediction[i]+1)/2) #从-1~1 --> 0~1
        plt.axis('off')
    plt.savefig('./face_DCGAN/image_GAN_{}.png'.format(epoch))
    # if epoch == 99:
    #     plt.show()

test_input = torch.randn(4, 100, device=device) #测试输入:16个长度为100的随机数


# DCGAN训练
D_loss = []
G_loss = []

for epoch in range(100):
    d_epoch_loss = 0 #判别器损失
    g_epoch_loss = 0 #生成器损失
    count = len(dataloader) #len(dataloader)返回批次数
    count1 = len(train_dataset) #len(train_dataset)返回样本数
    for step, (img, _) in enumerate(tqdm.tqdm(dataloader)):
        img = img.to(device)
        size = img.size(0) #该批次包含多少张图片
        random_noise = torch.randn(size, 100, device=device) #创建生成器的噪声输入

        d_optim.zero_grad() #判别器梯度清0
        real_output = dis(img) #将真实图像放到判别器上进行判断,得到对真实图像的预测结果
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output)) #得到判别器在真实图像上的损失
        d_real_loss.backward() #计算梯度

        gen_img = gen(random_noise) #得到生成图像
        fake_output = dis(gen_img.detach()) #将生成图像放到判别器上进行判断,得到对生成图像的预测结果,detach()为截断梯度
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output)) #得到判别器在生成图像上的损失
        d_fake_loss.backward()  # 计算梯度

        d_loss = d_real_loss + d_fake_loss #判别器的损失包含两部分
        d_optim.step() #判别器优化

        # 生成器
        g_optim.zero_grad() #生成器梯度清零
        fake_output = dis(gen_img) #将生成图像放到判别器上进行判断
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output)) #此处希望生成的图像能被判定为1
        g_loss.backward()  # 计算梯度
        g_optim.step() #生成器优化


        with torch.no_grad(): # loss累加的过程不需要计算梯度
            d_epoch_loss += d_loss.item() #将每一个批次的损失累加
            g_epoch_loss += g_loss.item() #将每一个批次的损失累加

    with torch.no_grad():  # loss累加的过程不需要计算梯度
        g_epoch_loss /= count
        d_epoch_loss /= count
        D_loss.append(d_epoch_loss) #保存每一个epoch的平均loss
        G_loss.append(g_epoch_loss) #保存每一个epoch的平均loss
        gen_img_plot(gen, epoch, test_input)  # 每个epoch会生成一张图
        print('Epoch:', epoch)

    plt.figure(figsize=(10, 10))
    plt.plot(range(1, len(D_loss)+1), D_loss, label='D_loss')
    plt.plot(range(1, len(G_loss)+1), G_loss, label='G_loss')
    plt.xlabel('epoch')  # 横轴名称
    plt.legend()
    plt.savefig('loss.png')  # 保存图片


# if __name__ == '__main__':
#     x = torch.rand((4, 3, 224, 224))
#     model = Discriminator()
#     out = model(x)
#     print(out.shape)

在卡通人物数据集上的训练效果可视化

猜你喜欢

转载自blog.csdn.net/m0_56247038/article/details/130270514
今日推荐