深度学习《CGAN模型》

一:介绍
CGAN全程是Conditional Generative Adversarial Network,回想一下,传统的GAN或者其他的GAN都是通过一堆的训练数据,最后训练出了G网络,随机输入噪声最后产生的数据是这些训练数据类别中之一,我们提前无法预测是那哪一个?

因此,我们有的时候需要定向指定生成某些数据,比如我们想让G生成飞机,数字9,等等的图片数据。

怎么做呢:
1:就是给网络的输入噪声数据增加一些类别上的信息,就是说给定某些类别条件下,生成指定的数据,所以输入数据会有一些变化;

2:然后在损失函数那里,我们目标不再是输出1/0,也就是不再是简单的输出真实和构造。当判定是真实数据的时候,还需要判定出是哪一类别的图片。一般使用one-hot表示。
在这里插入图片描述

上图表示,改变输入噪声数据,给z增加类别y信息,怎么增加呢,就是简单的维度拼接,y可以是一个one-hot向量,或者其他表达形式。对于真实数据x不做变化,只用y来获取D的输出结果。

判别器D最后也应该输出是哪个类别,并且按照类别最小化来训练,也就是希望D(X)尽可能接近y。

二:实例操作
拿MNIST数据练手

网络的结构什么的都没有改变,唯一变化的就是,生成的噪声z拼接上了数据的类别标签,D的输出是数据的类别的one-hot向量,而不仅仅是0/1.
详细代码如下:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import pickle
import copy

import matplotlib.gridspec as gridspec
from torchvision.utils import save_image
import os

# 定义展示图片的函数
def show_images(images):  # 定义画图工具
    print('images: ', images.shape)
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return

def deprocess_img(img):
    out = 0.5 * (img + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out



# step 1: ===========================================加载数据
batch_size = 128
noise_dim = 100  # 噪声维度,还是选择100维度
label_dim = 10  # 标签维度,10个数字,10个维度
z_dimension = noise_dim + label_dim  # z dimension = 100 noise dim + 10 one-hot dim

transform_img = transforms.Compose([transforms.ToTensor()])
trainset = MNIST('./data', train=True, transform=transform_img, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)



# step 2: ===========================================定义模型
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(1, 32, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2)),

            nn.Conv2d(32, 64, 5, stride=1, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.MaxPool2d((2, 2))
        )
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 64, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 10),
            nn.Sigmoid()
        )

    def forward(self, x):  # x: [batch_size, 1, 28, 28]
        x = self.dis(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x  # [batch_size, 10]


class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        self.fc = nn.Linear(input_size, num_feature)  # 1*56*56
        self.gen = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),

            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),

            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),

            nn.Conv2d(25, 1, 2, stride=2),
            nn.Tanh()
        )

    def forward(self, x):  # x: [batch_size, 110]
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.gen(x)
        return x  # [batch_size, 1, 28, 28]

# 实例化模型
D_Net = discriminator()
G_Net = generator(z_dimension, 3136)  # 1*56*56


# step 3: ===========================================定义优化器和损失函数
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D_Net.parameters(), lr=0.0003)
g_optimizer = optim.Adam(G_Net.parameters(), lr=0.0003)


# step 4: ===========================================开始训练
if __name__ == "__main__":
    iter_count = 0
    show_every = 100

    epoch = 100
    gepoch = 1
    for i in range(epoch):
        for (img, label) in trainloader:
            img = Variable(img)
            print(img.shape)

            # 生成 lable 的 one-hot 向量,且设置对应类别位置是 1
            labels_onehot = np.zeros((img.shape[0], label_dim))
            labels_onehot[np.arange(img.shape[0]), label.numpy()] = 1

            # 生成随机向量,也就是噪声z,带有标签信息
            z = Variable(torch.randn(img.shape[0], noise_dim))
            z = np.concatenate((z.numpy(), labels_onehot), axis=1)
            z = Variable(torch.from_numpy(z).float())

            # 真实数据标签和虚假数据标签,
            real_label = Variable(torch.from_numpy(labels_onehot).float())  # 真实label对应类别是为1
            fake_label = Variable(torch.zeros(img.shape[0], label_dim))  # 假的label全是为0

            # compute loss of real_img
            real_out = D_Net(img)  # 真实图片送入判别器D输出0~1
            d_loss_real = criterion(real_out, real_label)  # 得到loss

            # compute loss of fake_img
            fake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片
            fake_out = D_Net(fake_img)  # 判别器判断假的图片
            d_loss_fake = criterion(fake_out, fake_label)  # 假的图片的loss

            # D bp and optimize
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()  # 判别器D的梯度归零
            d_loss.backward()  # 反向传播
            d_optimizer.step()  # 更新判别器D参数

            # 生成器G的训练compute loss of fake_img
            for j in range(gepoch):
                fake_img = G_Net(z)  # 将向量放入生成网络G生成一张图片
                output = D_Net(fake_img)  # 经过判别器得到结果
                g_loss = criterion(output, real_label)  # 得到假的图片与真实标签的loss
                # bp and optimize
                g_optimizer.zero_grad()  # 生成器G的梯度归零
                g_loss.backward()  # 反向传播
                g_optimizer.step()  # 更新生成器G参数
                print("G")

            # 利用模型进行测试,指定按照顺序生成0~9的数字
            if (iter_count % show_every == 0):
                test_batch_size = 10
                test_label = torch.from_numpy(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
                labels_onehot = np.zeros((test_batch_size, label_dim))
                labels_onehot[np.arange(test_batch_size), test_label.numpy()] = 1

                # 生成随机向量,也就是噪声z,带有标签信息
                test_z = Variable(torch.randn(test_batch_size, noise_dim))
                test_z = np.concatenate((test_z.numpy(), labels_onehot), axis=1)
                test_z = Variable(torch.from_numpy(test_z).float())
                fake_img = G_Net(test_z)  # 将向量放入生成网络G生成一张图片

                # imgs_numpy = deprocess_img(fake_img.data.cpu().numpy())
                # show_images(imgs_numpy)
                # plt.show()
                real_images = deprocess_img(fake_img.data)
                save_image(real_images, 'D:/software/Anaconda3/doc/3D_Img/cgan/test_%d.png' % (iter_count))

            iter_count += 1
            print('iter_count: ', iter_count)

最后按照顺序生成0~9的图像效果还是很不错的。
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述

猜你喜欢

转载自blog.csdn.net/qq_29367075/article/details/109149211