Generative Adversarial Nets (Xiaobaixue GAN Series One)

Link to the original paper: https://arxiv.org/abs/1406.2661

Introduction

Core idea: borrow the framework of a two-player zero-sum (minimax) game to train the generative network, so that the network can learn the probability distribution underlying the existing data.

In Figure 1 of the paper, the black dots are the real data distribution, the green solid line is the distribution learned by G (the generator), and the blue dashed line is the output of D (the discriminator): regions where the blue curve is high are judged to be real data, and regions where it is low are judged to be fake. At the start of training (a), the blue line only roughly separates real from fake data. In (b), after D is trained once, the separation is much clearer. In (c), training G pushes its distribution toward the real data. Repeating this alternation n times, by (d) the distribution of G has completely fit the real data distribution, and D can no longer distinguish real from fake.

Core process:

The game between D and G is carried out by jointly optimizing a single value function V.
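
Here V is the minimax value function defined in the paper:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))]

D is trained to maximize V (assigning high scores to real samples and low scores to generated ones), while G is trained to minimize it.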

Basic structure

Discriminator training:

Training the discriminator follows the usual supervised procedure: label the real data 1 and the generated data 0, pass both through the discriminator, compute the loss against the labels, and update D. During this step the value of V increases.
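
A minimal sketch of one such update (a hypothetical helper; it assumes a discriminator D with sigmoid output and BCELoss, matching the full script below):

import torch

bce = torch.nn.BCELoss()

def d_step(D, optimizer_D, real_batch, fake_batch):
    # One discriminator update: real samples are labeled 1, fakes 0.
    valid = torch.ones(real_batch.size(0), 1, device=real_batch.device)  # labels for real data
    fake = torch.zeros(fake_batch.size(0), 1, device=fake_batch.device)  # labels for generated data

    optimizer_D.zero_grad()
    # detach() keeps gradients from flowing back into the generator
    d_loss = (bce(D(real_batch), valid) + bce(D(fake_batch.detach()), fake)) / 2
    d_loss.backward()
    optimizer_D.step()
    return d_loss.item()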

For a fixed G, the value of V is maximized by the optimal discriminator:

D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}

The output of D therefore lies in (0, 1), and when the distribution of G equals the real data distribution, D outputs 0.5 everywhere.
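
This follows from the pointwise argument in the paper: for a fixed G, V is an integral over x of p_{data}(x) \log D(x) + p_g(x) \log(1 - D(x)), and for constants a, b > 0 the function

f(y) = a \log y + b \log(1 - y)

attains its maximum on (0, 1) at y^* = a / (a + b), which gives the expression for D^*(x) above and equals 1/2 when a = b.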

Generator training:

The goal of the generator is to fit the actual data distribution as closely as possible, so that the discriminator cannot tell real from fake; the generator loss is designed around this.
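
Formally, G minimizes the second term of V. The paper also notes a non-saturating alternative, maximizing \log D(G(z)) instead of minimizing \log(1 - D(G(z))), which gives stronger gradients early in training and is what the code below implements by scoring fakes against the "real" label:

\min_G \; \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] \quad \longrightarrow \quad \max_G \; \mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]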

Substituting the optimal discriminator D^* into V gives the generator's effective loss:

C(G) = \mathbb{E}_{x \sim p_{data}}\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\right] + \mathbb{E}_{x \sim p_g}\left[\log \frac{p_g(x)}{p_{data}(x) + p_g(x)}\right]

Intuitively, each log term is at most zero (a term vanishes only if D assigns probability 1 to every real sample, or probability 0 to every generated one), so C(G) is negative in general. The minimum is reached when G exactly fits the real data distribution: then D^*(x) = 0.5 everywhere and

C(G) = \log \tfrac{1}{2} + \log \tfrac{1}{2} = -\log 4

Introducing the Jensen–Shannon divergence (JSD) makes this explicit:

C(G) = -\log 4 + KL\left(p_{data} \,\middle\|\, \frac{p_{data} + p_g}{2}\right) + KL\left(p_g \,\middle\|\, \frac{p_{data} + p_g}{2}\right) = -\log 4 + 2 \cdot \mathrm{JSD}(p_{data} \,\|\, p_g)

Since the JSD is non-negative and zero only when the two distributions coincide, subtracting the constant -\log 4 leaves a loss that is exactly zero when G fits the real data distribution perfectly.
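
As a quick numeric illustration (a standalone helper, not part of the training code), the JS divergence between two discrete distributions can be computed directly:

import numpy as np

def js_divergence(p, q, eps=1e-12):
    # Jensen-Shannon divergence between two discrete distributions.
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    m = (p + q) / 2
    def kl(a, b):
        return np.sum(a * np.log((a + eps) / (b + eps)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.1, 0.4, 0.5]
print(js_divergence(p, p))                 # 0.0: the loss term vanishes when p_g == p_data
print(js_divergence(p, [0.5, 0.4, 0.1]))   # > 0 for mismatched distributions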

Overall

Optimizing D alone drives the value of V up, while optimizing G drives it down; the two networks thus confront each other in a minimax game.

Code practice and results

Code implementation (reference: https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py )

I made a small modification: in each iteration, G is trained twice and then D once:

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        # labels for real data
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
        # labels for generated data

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator 1
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        # measures how well G can fool D by passing fakes off as real

        g_loss.backward()
        optimizer_G.step()
        
        
        # -----------------
        #  Train Generator 2
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()



        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # D's ability to classify real samples as real
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # D's ability to classify generated samples as fake
        d_loss = (real_loss + fake_loss) / 2


        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
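
Running the script as-is (for example, python gan.py with the defaults) trains on MNIST for 200 epochs and saves a 5x5 grid of generated samples to images/ every 400 batches.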

Test results with MNIST

[Figure: grids of generated MNIST digits sampled during training]


Original post: blog.csdn.net/fan1102958151/article/details/106268019