深度学习《WGAN模型》

WGAN是一个对原始GAN进行重大改进的网络
(原文此处为一张配图,图片已丢失。)

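主要是在如下几个方面做了改进(原文此处为图片,已丢失;以下文字版根据代码中的 Importment Difference 标注整理):
1. 判别器最后一层去掉 sigmoid,直接输出一个不受限的分数;
2. 生成器和判别器的 loss 不再取 log(不再使用 BCELoss);
3. 把判别器参数的取值截断到 [-c, c] 区间内(代码中 c = clamp_num = 0.01);
4. 优化器不用 Adam,改用 RMSprop(也可以用 SGD)。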

实例测试代码如下:

我还是用自己的 16 张鸣人照片来做实验。上述每一个改进点,我在代码中都用 Importment Difference 标注了。

import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10
from pylab import plt
import os
import torchvision.datasets as datasets
from torchvision.utils import save_image


# To see where WGAN differs from a vanilla GAN, search this file for the marker
# "Importment Difference" (the tag is spelled exactly this way throughout).

# step 1: ========================================== hyper-parameters used by this script
class WGAN_Config:
    """Hyper-parameters for the WGAN demo, read as attributes of ``wgan_opt``."""

    lr = 0.0001  # learning rate passed to both RMSprop optimizers
    nz = 100  # noise dimension (input channels of the generator)
    image_size = 64  # side length of the square real/generated images

    nc = 3  # channels of the image (3 = RGB)
    ngf = 64  # base channel width of the generator
    ndf = 64  # base channel width of the discriminator (critic)

    batch_size = 16
    max_epoch = 5000  # set to 1 when debugging
    # WGAN weight clipping: the training loop clamps the critic's *weights*
    # (not its gradients) to [-clamp_num, clamp_num].
    clamp_num = 0.01


wgan_opt = WGAN_Config()


def deprocess_img(img, nc=None, image_size=None):
    """Map generator output from [-1, 1] back to [0, 1] so it can be saved.

    Inverts the ``Normalize([0.5]*3, [0.5]*3)`` preprocessing
    (x -> 0.5 * (x + 1)), clamps to the valid pixel range and reshapes the
    result to ``(-1, nc, image_size, image_size)``.

    Args:
        img: tensor of pixel values, e.g. the Tanh output of the generator.
            Its element count must be a multiple of
            ``nc * image_size * image_size``.
        nc: number of image channels; defaults to ``wgan_opt.nc``.
        image_size: image side length; defaults to ``wgan_opt.image_size``.

    Returns:
        Tensor of shape ``(N, nc, image_size, image_size)`` with values in [0, 1].
    """
    if nc is None:
        nc = wgan_opt.nc
    if image_size is None:
        image_size = wgan_opt.image_size
    out = (0.5 * (img + 1)).clamp(0, 1)
    # Use the configured channel count rather than a hard-coded 3, so the
    # reshape stays correct when wgan_opt.nc is changed (e.g. grayscale).
    return out.view(-1, nc, image_size, image_size)



# step 2: ========================================== the usual routine: load the dataset
# Preprocessing: resize the shorter edge to image_size, center-crop to a square
# image_size x image_size patch, convert to a tensor, and scale pixels to [-1, 1]
# (the same range as the generator's Tanh output).
transform = transforms.Compose([
    transforms.Resize(wgan_opt.image_size),
    # Resize(int) only fixes the shorter edge. Without this crop, non-square
    # photos yield tensors of different sizes that cannot be stacked into a
    # batch, and the discriminator's fixed 64x64 -> 1x1 geometry breaks.
    transforms.CenterCrop(wgan_opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Alternative dataset (CIFAR10); uncomment to use it instead of the folder below.
# dataset = CIFAR10(root='cifar10/', transform=transform, download=True)
# dataloader = t.utils.data.DataLoader(dataset, wgan_opt.batch_size, shuffle=True)

data_path = os.path.abspath("D:/software/Anaconda3/doc/3D_Naruto")
print(os.listdir(data_path))
# NOTE: ImageFolder reads images from *sub-directories* of data_path, so put
# the pictures in a folder inside data_path; otherwise the next line errors.
dataset = datasets.ImageFolder(root=data_path, transform=transform)
dataloader = t.utils.data.DataLoader(dataset, batch_size=wgan_opt.batch_size, shuffle=True)





# step 3: ========================================== define the WGAN generator (G) and discriminator (D)
class generator(nn.Module):
    """DCGAN-style generator: turns a (N, nz, 1, 1) noise tensor into images.

    Five transposed convolutions grow the spatial size 1 -> 4 -> 8 -> 16 ->
    32 -> 64 while shrinking channels ngf*8 -> ngf*4 -> ngf*2 -> ngf -> nc.
    The final Tanh keeps pixels in [-1, 1], the range of the normalised
    real images.
    """

    def __init__(self):
        super(generator, self).__init__()
        width = wgan_opt.ngf
        # Project the 1x1 noise to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(wgan_opt.nz, width * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
        ]
        # Three upsampling stages: each doubles H/W and halves the channels.
        for mult_in, mult_out in ((8, 4), (4, 2), (2, 1)):
            layers += [
                nn.ConvTranspose2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.ReLU(True),
            ]
        # Last upsampling straight to image channels, squashed into [-1, 1].
        layers += [
            nn.ConvTranspose2d(width, wgan_opt.nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.netg = nn.Sequential(*layers)

    def forward(self, imgs):
        return self.netg(imgs)

class discriminator(nn.Module):
    """DCGAN-style WGAN critic: scores each image with one unbounded number.

    Four stride-2 convolutions shrink the spatial size (64 -> 32 -> 16 -> 8
    -> 4 for image_size=64 inputs) and a final 4x4 convolution reduces it to
    1x1. ``forward`` then flattens that to one score per image.
    """

    def __init__(self):
        super(discriminator, self).__init__()
        width = wgan_opt.ndf
        # First stage has no BatchNorm.
        layers = [
            nn.Conv2d(wgan_opt.nc, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three downsampling stages: each halves H/W and doubles the channels.
        for mult_in, mult_out in ((1, 2), (2, 4), (4, 8)):
            layers += [
                nn.Conv2d(width * mult_in, width * mult_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Importment Difference 1: the final conv output is used as-is; no
        # Sigmoid squashes it into a probability any more.
        layers.append(nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False))
        self.netd = nn.Sequential(*layers)

    def forward(self, imgs):
        scores = self.netd(imgs)
        return scores.view(imgs.shape[0])

# Build the critic first, then the generator (same construction order as before).
netd, netg = discriminator(), generator()



# step 4: ========================================== initialise the parameters of both networks
# Explicit weight initialisation (new in this tutorial), applied via nn.Module.apply.
def weight_init(m):
    """Re-initialise the weights of a single module in place.

    Modules whose class name contains "Conv" (Conv2d, ConvTranspose2d) get
    weights drawn from N(0, 0.02); modules whose class name contains "Norm"
    (BatchNorm2d) get weights drawn from N(1, 0.02). Biases and all other
    modules are left untouched. Initialisation is described as important for
    WGAN training.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'Norm' in kind:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)

# Initialise the critic first, then the generator.
for net in (netd, netg):
    net.apply(weight_init)


# step 5: ========================================== optimizers: RMSprop rather than Adam
# SGD is also a recommended choice.
# Importment Difference 2: use RMSprop instead of Adam.
optimizerD, optimizerG = [
    RMSprop(net.parameters(), lr=wgan_opt.lr) for net in (netd, netg)
]

# Importment Difference 3: no log in the loss, so no criterion such as
# nn.BCELoss() is created; the training loop back-propagates raw critic scores.




# step 6: ========================================== training
# Fixed noise, so the periodically saved samples are comparable over time.
rand_noise = t.randn(wgan_opt.batch_size, wgan_opt.nz, 1, 1)
sample_dir = 'D:/software/Anaconda3/doc/3D_Img/wgan2'
os.makedirs(sample_dir, exist_ok=True)  # save_image fails if the folder is missing

iter_count = 0
# Critic updates per generator update. The old test `(ii + 1) % 1 == 0` was
# always true; 1 keeps that behaviour. The ratio is counted on iter_count
# because the per-epoch batch index restarts at 0 every epoch (and with a
# 16-image dataset and batch_size 16 it never exceeds 0).
n_critic = 1
for epoch in range(wgan_opt.max_epoch):
    for imgs, _ in dataloader:  # real images; ImageFolder labels are unused
        print(imgs.shape)

        # Importment Difference 4: clip the critic's weights to [-c, c]
        for parm in netd.parameters():
            parm.data.clamp_(-wgan_opt.clamp_num, wgan_opt.clamp_num)

        # ----- train discriminator (critic) -----
        # No log loss: backward(+1) on real scores and backward(-1) on fake
        # scores makes the step minimise sum(D(real)) - sum(D(fake)).
        # ones_like matches the actual batch size, so a final partial batch
        # does not raise a shape mismatch.
        netd.zero_grad()

        output = netd(imgs)
        output.backward(t.ones_like(output))

        noise = t.randn(imgs.size(0), wgan_opt.nz, 1, 1)
        with t.no_grad():  # fakes only update the critic here; no graph into netg
            fake_pic = netg(noise)
        output2 = netd(fake_pic)
        output2.backward(-t.ones_like(output2))

        optimizerD.step()

        # ----- train generator -----
        # The generator minimises sum(D(G(z))), the opposite direction to the
        # critic's fake term. Gradients that also land on netd here are
        # discarded by netd.zero_grad() before the next critic step.
        if (iter_count + 1) % n_critic == 0:
            netg.zero_grad()
            noise = t.randn(imgs.size(0), wgan_opt.nz, 1, 1)

            fake_pic = netg(noise)
            output = netd(fake_pic)

            output.backward(t.ones_like(output))
            optimizerG.step()

        if iter_count % 50 == 0:
            with t.no_grad():  # sampling only; no autograd graph needed
                rand_imgs = deprocess_img(netg(rand_noise))
            save_image(rand_imgs, '%s/test_%d.png' % (sample_dir, iter_count))

        iter_count += 1
        print('iter_count: ', iter_count)

效果如下:
下图都是用固定的随机噪声生成的图片。由于训练太耗时,训练次数还不够,效果有限。
(原文此处为效果图,图片已丢失。)

猜你喜欢

转载自blog.csdn.net/qq_29367075/article/details/109140777