1. Introduction
This tutorial introduces DCGANs through an example. We will train a Generative Adversarial Network (GAN) to generate new celebrities after showing it many pictures of real celebrities. Most of the code here comes from the dcgan implementation in pytorch/examples. This document gives a thorough introduction to the implementation and clarifies how and why the model works. Don't worry: no prior knowledge of GANs is required, but first-time readers may need to spend some time reasoning about what is actually happening under the hood. Also, to save time, it helps to have a GPU, or two. Let's start from the beginning.
2. Generative Adversarial Networks
2.1 What is GAN
GANs are a framework for teaching a deep learning (DL) model to capture the distribution of the training data, so that we can generate new data from that same distribution. GANs were proposed by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They consist of two distinct models: a generator and a discriminator. The job of the generator is to produce "fake" images that look like the training images; the job of the discriminator is to look at an image and decide whether it is a real training image or a fake from the generator. During training, the generator constantly tries to outsmart the discriminator by producing better and better fakes, while the discriminator works to better detect and accurately classify real and fake images. The equilibrium of this game is reached when the generator produces perfect fakes that look as if they came straight from the training data, and the discriminator is left to always guess, at 50% confidence, whether a given image is real or fake.
Now, let us define some notation used throughout this tutorial.

For the discriminator: let x be image data and let D(x) be the discriminator network, which outputs the (scalar) probability that x came from the training data rather than from the generator. D(x) should be high when x comes from the training data and low when x comes from the generator.

For the generator: let z be a latent space vector sampled from a standard normal distribution, and let G(z) be the generator function that maps z to data space. The goal of G is to estimate the distribution that the training data comes from, so it can generate fake samples from that estimated distribution. D(G(z)) is then the (scalar) probability that the output of the generator is a real image.

As described in Goodfellow's paper, D and G play a minimax game in which D tries to maximize the probability of correctly classifying reals and fakes, while G tries to minimize the probability that D will predict its outputs are fake:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log(1 - D(G(z)))\big]$$

In theory, the solution to this minimax game is reached where G reproduces the training data distribution and D is reduced to randomly guessing whether its input is real or fake. However, the convergence theory of GANs is still an area of active research, and in practice models are not always trained to this point.
2.2 What is DCGAN
DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first proposed by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. It takes a 3x64x64 input image and outputs a scalar probability that the input came from the real data distribution. The generator is composed of convolutional-transpose layers, batch norm layers, and ReLU activations. Its input is a latent vector z drawn from a standard normal distribution, and its output is a 3x64x64 RGB image. The strided conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image (a shape-arithmetic sketch follows the link list below). In the paper, the authors also give some tips on how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.
Related Links:
Paper: https://arxiv.org/pdf/1511.06434.pdf
strided convolution layers: https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d
batch norm layers: https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d
LeakyReLU activations: https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU
convolutional-transpose layers: https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d
ReLU activations: https://pytorch.org/docs/stable/nn.html#relu
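To build intuition for how these strided layers change spatial size, here is a minimal sketch (not part of the original tutorial; it assumes the kernel-size-4 settings used later in this tutorial) that checks the shape arithmetic with PyTorch:

import torch
import torch.nn as nn

# A strided conv-transpose layer upsamples: a 100x1x1 latent vector becomes a 512x4x4 volume.
# Output size = (in - 1)*stride - 2*padding + kernel = (1 - 1)*1 - 0 + 4 = 4.
up = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False)
z = torch.randn(1, 100, 1, 1)
print(up(z).shape)    # torch.Size([1, 512, 4, 4])

# A strided convolution downsamples: a 3x64x64 image becomes a 64x32x32 volume.
# Output size = floor((in + 2*padding - kernel)/stride) + 1 = floor((64 + 2 - 4)/2) + 1 = 32.
down = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)
x = torch.randn(1, 3, 64, 64)
print(down(x).shape)  # torch.Size([1, 64, 32, 32])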
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use this if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Output result:
Random Seed: 999
3. DCGAN implementation process
3.1 Input
Let's define the input data to run our tutorial:
dataroot: the path to the root directory of the dataset. We will discuss the dataset further in the next section
workers: the number of worker threads the DataLoader uses to load the data
batch_size: the batch size used in training. The DCGAN paper uses a batch size of 128
image_size: the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details
nc: the number of color channels in the input images. For color images this is 3
nz: the length of the latent vector
ngf: relates to the depth of the feature maps carried through the generator
ndf: sets the depth of the feature maps propagated through the discriminator
num_epochs: the number of training epochs to run. Training for longer will probably lead to better results but will also take longer
lr: the learning rate for training. As described in the DCGAN paper, this number should be 0.0002
beta1: the beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5
ngpu: the number of GPUs available. If this is 0, the code will run in CPU mode. If it is greater than 0, the code will run on that number of GPUs
# Root directory of the dataset
dataroot = "data/celeba"

# Number of worker threads for loading the data
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of the training images. All images will be resized to this size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of the latent vector z (i.e. the size of the generator input)
nz = 100

# Size of the feature maps in the generator
ngf = 64

# Size of the feature maps in the discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for the optimizers
lr = 0.0002

# Beta1 hyperparameter for the Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
3.2 Data
In this tutorial, we will use the Celeb-A Faces dataset, which can be downloaded from the linked site or from Google Drive. The dataset downloads as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:
/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...
This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.
# We can use an image folder dataset the way we have it set up.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some of the training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(), (1,2,0)))
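As a quick sanity check (a hypothetical snippet, not part of the original tutorial, assuming the directory layout above and the code just shown), you can confirm that ImageFolder found the images and that batches have the expected shape and value range; the full CelebA download contains 202,599 images:

# Quick sanity check on the dataset and loader
print(len(dataset))            # should be 202599 for the full CelebA download
images, _ = next(iter(dataloader))
print(images.shape)            # torch.Size([128, 3, 64, 64]) with the defaults above
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0 after Normalize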
3.3 Implementation
With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then discuss the generator, discriminator, loss functions, and training loop in detail.
3.3.1 Weight initialization
In the DCGAN paper, the authors specify that all model weights should be randomly initialized from a normal distribution with mean = 0 and stdev = 0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
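As a small sanity check (a sketch, not part of the original tutorial), you can apply weights_init to a single layer and confirm that the sample statistics land near the targets:

conv = nn.Conv2d(3, 64, 4, 2, 1, bias=False)
weights_init(conv)
# Expect values close to 0.0 and 0.02 (sample statistics, so not exact)
print(conv.weight.data.mean().item(), conv.weight.data.std().item())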
3.3.2 Generator
The generator, G, is designed to map the latent space vector (z) to data space. Since our data are images, converting z to data space ultimately means creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two-dimensional convolutional-transpose layers, each paired with a 2d batch norm layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of [-1, 1]. It is worth noting the existence of the batch norm layers after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. The DCGAN paper includes a diagram of this generator architecture.
Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in the code. nz is the length of the z input vector, ngf relates to the size of the feature maps propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.
Generator code
# Generator code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
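Each conv-transpose layer after the first doubles the spatial size, while the first expands the 1x1 latent input to 4x4. A minimal sketch (an illustration, not part of the original code) traces this using the conv-transpose output-size formula out = (in - 1) * stride - 2 * padding + kernel:

# Trace the generator's spatial sizes layer by layer, using the (kernel, stride, padding) of each layer
size = 1
for k, s, p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = (size - 1) * s - 2 * p + k
    print(size)   # prints 4, 8, 16, 32, 64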
Now, we can instantiate the generator and apply the weights_init function. View the printed model to see the structure of the generator object.
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)
Output result:
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
3.3.3 Discriminator
As mentioned above, the discriminator, D, is a binary classification network that takes an image as input and outputs the scalar probability that the input image is real (as opposed to fake). Here, D takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if the problem requires it, but there is significance to the use of the strided convolutions, BatchNorm, and LeakyReLUs. The DCGAN paper mentions that it is good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. The batch norm and leaky relu functions also promote healthy gradient flow, which is critical to the learning process of both G and D.
Discriminator code
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
Now, like the generator, we can create a discriminator, apply the weights_init function, and print the structure of the model.
# Create the discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)
Output result:
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
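With both networks instantiated, a quick end-to-end smoke test (a sketch, not part of the original tutorial, assuming the code above has already run) confirms that a batch of latent vectors flows through G and then D with the expected shapes:

with torch.no_grad():
    z = torch.randn(16, nz, 1, 1, device=device)
    fake = netG(z)                # shape: (16, 3, 64, 64)
    prob = netD(fake).view(-1)    # shape: (16,), scalar probabilities in (0, 1)
print(fake.shape, prob.shape)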
3.3.4 Loss function and optimizer
With D and G set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function defined in PyTorch. We also create a fixed batch of latent vectors drawn from a Gaussian distribution (fixed_noise); during training we will periodically feed fixed_noise into G to visualize its progress. Finally, we set up two separate Adam optimizers, one for D and one for G, both with lr = 0.0002 and beta1 = 0.5 as specified in the DCGAN paper:
# Initialize the BCELoss function
criterion = nn.BCELoss()

# Create a batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish the convention for real and fake labels during training
real_label = 1
fake_label = 0

# Set up Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
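For reference, for a prediction x in (0, 1) and a target y, BCELoss computes (per element, before reduction):

$$\ell(x, y) = -\big[\, y \cdot \log x + (1 - y) \cdot \log(1 - x) \,\big]$$

With y = 1 (real_label) this reduces to -log x, and with y = 0 (fake_label) it reduces to -log(1 - x). This is how the log D(x) and log(1 - D(G(z))) terms of the value function are recovered in the training loop below, simply by choosing which label to pair with which batch.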
3.3.5 Training
Finally, now that all parts of the GAN framework are defined, we can train it. Be mindful that training GANs is somewhat of an art form, since incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow's paper, while abiding by some of the best practices shown in ganhacks. Namely, we will construct different mini-batches for real and fake images, and also adjust G's objective function to maximize log D(G(z)), as sketched below.
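The adjustment to G's objective mentioned above is the standard non-saturating trick from Goodfellow's paper: minimizing log(1 - D(G(z))) directly provides weak gradients early in training, when D easily rejects the generator's fakes. Labeling fakes as real in G's loss (as the code below does) amounts to the swap

$$\min_G \; \mathbb{E}_{z}\big[\log(1 - D(G(z)))\big] \;\longrightarrow\; \max_G \; \mathbb{E}_{z}\big[\log D(G(z))\big]$$

Both objectives have the same fixed point, but the second gives much stronger gradients when D(G(z)) is small.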
Training is split into two main parts: the first part updates the discriminator and the second part updates the generator.
* Part 1: Train the discriminator
* Part 2: Train the generator
This step may take a while, depending on how many epochs you run and whether you removed some data from the dataset.
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Output result:
Starting Training Loop...
[0/5][0/1583] Loss_D: 2.0937 Loss_G: 5.2059 D(x): 0.5704 D(G(z)): 0.6680 / 0.0090
[0/5][50/1583] Loss_D: 0.3567 Loss_G: 12.2064 D(x): 0.9364 D(G(z)): 0.1409 / 0.0000
[0/5][100/1583] Loss_D: 0.3519 Loss_G: 8.8873 D(x): 0.8714 D(G(z)): 0.0327 / 0.0004
[0/5][150/1583] Loss_D: 0.5300 Loss_G: 6.6410 D(x): 0.8918 D(G(z)): 0.2776 / 0.0030
[0/5][200/1583] Loss_D: 0.2543 Loss_G: 4.3581 D(x): 0.8662 D(G(z)): 0.0844 / 0.0218
[0/5][250/1583] Loss_D: 0.7170 Loss_G: 4.2652 D(x): 0.8285 D(G(z)): 0.3227 / 0.0370
[0/5][300/1583] Loss_D: 0.5739 Loss_G: 4.2060 D(x): 0.8329 D(G(z)): 0.2577 / 0.0305
[0/5][350/1583] Loss_D: 0.8139 Loss_G: 6.5680 D(x): 0.9163 D(G(z)): 0.3844 / 0.0062
[0/5][400/1583] Loss_D: 0.4089 Loss_G: 5.0794 D(x): 0.8580 D(G(z)): 0.1221 / 0.0243
[0/5][450/1583] Loss_D: 0.4785 Loss_G: 4.1612 D(x): 0.7154 D(G(z)): 0.0514 / 0.0258
[0/5][500/1583] Loss_D: 0.3748 Loss_G: 4.2888 D(x): 0.8135 D(G(z)): 0.0955 / 0.0264
[0/5][550/1583] Loss_D: 0.5247 Loss_G: 5.9952 D(x): 0.8347 D(G(z)): 0.1580 / 0.0075
[0/5][600/1583] Loss_D: 0.7765 Loss_G: 2.2662 D(x): 0.5977 D(G(z)): 0.0408 / 0.1708
[0/5][650/1583] Loss_D: 0.6914 Loss_G: 4.4091 D(x): 0.6502 D(G(z)): 0.0266 / 0.0238
[0/5][700/1583] Loss_D: 0.5679 Loss_G: 5.3386 D(x): 0.8476 D(G(z)): 0.2810 / 0.0098
[0/5][750/1583] Loss_D: 0.3717 Loss_G: 5.1295 D(x): 0.9221 D(G(z)): 0.2207 / 0.0106
[0/5][800/1583] Loss_D: 0.4423 Loss_G: 3.1339 D(x): 0.8418 D(G(z)): 0.1655 / 0.0820
[0/5][850/1583] Loss_D: 0.3391 Loss_G: 4.8393 D(x): 0.7920 D(G(z)): 0.0315 / 0.0169
[0/5][900/1583] Loss_D: 0.4346 Loss_G: 4.3887 D(x): 0.8883 D(G(z)): 0.2270 / 0.0202
[0/5][950/1583] Loss_D: 0.5315 Loss_G: 4.6233 D(x): 0.8393 D(G(z)): 0.2490 / 0.0188
[0/5][1000/1583] Loss_D: 0.5281 Loss_G: 6.1465 D(x): 0.9643 D(G(z)): 0.3270 / 0.0049
[0/5][1050/1583] Loss_D: 0.5515 Loss_G: 6.4457 D(x): 0.9262 D(G(z)): 0.3361 / 0.0033
[0/5][1100/1583] Loss_D: 0.4430 Loss_G: 4.7469 D(x): 0.7306 D(G(z)): 0.0184 / 0.0202
[0/5][1150/1583] Loss_D: 0.7336 Loss_G: 2.6978 D(x): 0.6552 D(G(z)): 0.1293 / 0.1059
[0/5][1200/1583] Loss_D: 0.2927 Loss_G: 4.7480 D(x): 0.8858 D(G(z)): 0.1329 / 0.0173
[0/5][1250/1583] Loss_D: 2.0790 Loss_G: 5.1077 D(x): 0.2722 D(G(z)): 0.0036 / 0.0172
[0/5][1300/1583] Loss_D: 0.2431 Loss_G: 5.0027 D(x): 0.8812 D(G(z)): 0.0816 / 0.0169
[0/5][1350/1583] Loss_D: 0.2969 Loss_G: 4.6160 D(x): 0.9126 D(G(z)): 0.1609 / 0.0183
[0/5][1400/1583] Loss_D: 0.7158 Loss_G: 2.9825 D(x): 0.6117 D(G(z)): 0.0292 / 0.0900
[0/5][1450/1583] Loss_D: 0.7513 Loss_G: 1.9396 D(x): 0.6186 D(G(z)): 0.0559 / 0.2414
[0/5][1500/1583] Loss_D: 0.4366 Loss_G: 3.9122 D(x): 0.8736 D(G(z)): 0.2231 / 0.0325
[0/5][1550/1583] Loss_D: 0.3204 Loss_G: 4.2434 D(x): 0.8395 D(G(z)): 0.0929 / 0.0271
[1/5][0/1583] Loss_D: 0.5077 Loss_G: 4.8872 D(x): 0.9331 D(G(z)): 0.3082 / 0.0122
[1/5][50/1583] Loss_D: 0.5637 Loss_G: 3.6652 D(x): 0.8525 D(G(z)): 0.2684 / 0.0414
[1/5][100/1583] Loss_D: 0.4047 Loss_G: 3.6624 D(x): 0.8323 D(G(z)): 0.1508 / 0.0473
[1/5][150/1583] Loss_D: 0.3858 Loss_G: 3.3070 D(x): 0.7873 D(G(z)): 0.0826 / 0.0583
[1/5][200/1583] Loss_D: 0.4348 Loss_G: 3.6292 D(x): 0.8390 D(G(z)): 0.1908 / 0.0417
[1/5][250/1583] Loss_D: 0.5953 Loss_G: 2.1992 D(x): 0.6572 D(G(z)): 0.0649 / 0.1540
[1/5][300/1583] Loss_D: 0.4062 Loss_G: 3.8770 D(x): 0.8655 D(G(z)): 0.2012 / 0.0310
[1/5][350/1583] Loss_D: 0.9472 Loss_G: 1.4837 D(x): 0.4979 D(G(z)): 0.0322 / 0.2947
[1/5][400/1583] Loss_D: 0.5269 Loss_G: 2.6842 D(x): 0.9150 D(G(z)): 0.2922 / 0.1248
[1/5][450/1583] Loss_D: 0.6091 Loss_G: 3.8100 D(x): 0.8194 D(G(z)): 0.2720 / 0.0360
[1/5][500/1583] Loss_D: 0.5674 Loss_G: 3.2716 D(x): 0.8279 D(G(z)): 0.2452 / 0.0610
[1/5][550/1583] Loss_D: 0.8366 Loss_G: 5.5266 D(x): 0.9263 D(G(z)): 0.4840 / 0.0076
[1/5][600/1583] Loss_D: 0.6098 Loss_G: 2.2626 D(x): 0.6424 D(G(z)): 0.0640 / 0.1451
[1/5][650/1583] Loss_D: 0.3970 Loss_G: 3.4130 D(x): 0.8347 D(G(z)): 0.1613 / 0.0491
[1/5][700/1583] Loss_D: 0.5422 Loss_G: 3.1208 D(x): 0.7889 D(G(z)): 0.1972 / 0.0699
[1/5][750/1583] Loss_D: 0.9114 Loss_G: 1.3789 D(x): 0.5066 D(G(z)): 0.0350 / 0.3440
[1/5][800/1583] Loss_D: 1.1917 Loss_G: 5.6081 D(x): 0.9548 D(G(z)): 0.6084 / 0.0064
[1/5][850/1583] Loss_D: 0.4852 Loss_G: 1.9158 D(x): 0.7103 D(G(z)): 0.0636 / 0.1943
[1/5][900/1583] Loss_D: 0.5322 Loss_G: 2.8350 D(x): 0.7762 D(G(z)): 0.1994 / 0.0868
[1/5][950/1583] Loss_D: 0.7765 Loss_G: 1.7411 D(x): 0.5553 D(G(z)): 0.0732 / 0.2260
[1/5][1000/1583] Loss_D: 0.5518 Loss_G: 4.5488 D(x): 0.9244 D(G(z)): 0.3354 / 0.0161
[1/5][1050/1583] Loss_D: 0.4237 Loss_G: 3.2012 D(x): 0.8118 D(G(z)): 0.1651 / 0.0583
[1/5][1100/1583] Loss_D: 1.1245 Loss_G: 5.5327 D(x): 0.9483 D(G(z)): 0.5854 / 0.0090
[1/5][1150/1583] Loss_D: 0.5543 Loss_G: 1.9609 D(x): 0.6777 D(G(z)): 0.0933 / 0.1936
[1/5][1200/1583] Loss_D: 0.4945 Loss_G: 2.0234 D(x): 0.7580 D(G(z)): 0.1329 / 0.1742
[1/5][1250/1583] Loss_D: 0.5637 Loss_G: 2.9421 D(x): 0.7701 D(G(z)): 0.2123 / 0.0780
[1/5][1300/1583] Loss_D: 0.6178 Loss_G: 2.5512 D(x): 0.7828 D(G(z)): 0.2531 / 0.1068
[1/5][1350/1583] Loss_D: 0.4302 Loss_G: 2.5266 D(x): 0.8525 D(G(z)): 0.2053 / 0.1141
[1/5][1400/1583] Loss_D: 1.5730 Loss_G: 1.4042 D(x): 0.2854 D(G(z)): 0.0183 / 0.3325
[1/5][1450/1583] Loss_D: 0.6962 Loss_G: 3.3562 D(x): 0.8652 D(G(z)): 0.3732 / 0.0534
[1/5][1500/1583] Loss_D: 0.7635 Loss_G: 1.4343 D(x): 0.5765 D(G(z)): 0.0807 / 0.3056
[1/5][1550/1583] Loss_D: 0.4228 Loss_G: 3.3460 D(x): 0.8169 D(G(z)): 0.1671 / 0.0522
[2/5][0/1583] Loss_D: 0.8332 Loss_G: 1.5990 D(x): 0.6355 D(G(z)): 0.2409 / 0.2433
[2/5][50/1583] Loss_D: 0.4681 Loss_G: 2.0920 D(x): 0.7295 D(G(z)): 0.0978 / 0.1626
[2/5][100/1583] Loss_D: 0.7995 Loss_G: 2.8227 D(x): 0.7766 D(G(z)): 0.3675 / 0.0828
[2/5][150/1583] Loss_D: 0.3804 Loss_G: 2.6037 D(x): 0.8523 D(G(z)): 0.1729 / 0.1016
[2/5][200/1583] Loss_D: 0.9238 Loss_G: 0.8758 D(x): 0.5284 D(G(z)): 0.1343 / 0.4542
[2/5][250/1583] Loss_D: 0.5205 Loss_G: 2.6795 D(x): 0.7778 D(G(z)): 0.1875 / 0.0934
[2/5][300/1583] Loss_D: 0.7720 Loss_G: 3.8033 D(x): 0.9307 D(G(z)): 0.4405 / 0.0384
[2/5][350/1583] Loss_D: 0.5825 Loss_G: 3.3677 D(x): 0.9309 D(G(z)): 0.3609 / 0.0470
[2/5][400/1583] Loss_D: 0.4290 Loss_G: 2.5963 D(x): 0.7495 D(G(z)): 0.1047 / 0.0976
[2/5][450/1583] Loss_D: 0.7161 Loss_G: 4.0053 D(x): 0.8270 D(G(z)): 0.3655 / 0.0252
[2/5][500/1583] Loss_D: 0.5238 Loss_G: 2.3543 D(x): 0.8084 D(G(z)): 0.2320 / 0.1330
[2/5][550/1583] Loss_D: 0.7724 Loss_G: 2.2096 D(x): 0.6645 D(G(z)): 0.2238 / 0.1417
[2/5][600/1583] Loss_D: 0.4897 Loss_G: 2.8286 D(x): 0.7776 D(G(z)): 0.1738 / 0.0832
[2/5][650/1583] Loss_D: 1.2680 Loss_G: 4.7502 D(x): 0.8977 D(G(z)): 0.6179 / 0.0149
[2/5][700/1583] Loss_D: 0.7054 Loss_G: 3.3908 D(x): 0.8692 D(G(z)): 0.3753 / 0.0490
[2/5][750/1583] Loss_D: 0.4933 Loss_G: 3.6839 D(x): 0.8933 D(G(z)): 0.2845 / 0.0368
[2/5][800/1583] Loss_D: 0.6246 Loss_G: 2.7728 D(x): 0.8081 D(G(z)): 0.2968 / 0.0821
[2/5][850/1583] Loss_D: 1.2216 Loss_G: 1.1784 D(x): 0.3819 D(G(z)): 0.0446 / 0.3623
[2/5][900/1583] Loss_D: 0.6578 Loss_G: 1.7445 D(x): 0.6494 D(G(z)): 0.1271 / 0.2173
[2/5][950/1583] Loss_D: 0.8333 Loss_G: 1.2805 D(x): 0.5193 D(G(z)): 0.0543 / 0.3210
[2/5][1000/1583] Loss_D: 0.7348 Loss_G: 0.7953 D(x): 0.5920 D(G(z)): 0.1265 / 0.4815
[2/5][1050/1583] Loss_D: 0.6809 Loss_G: 3.7259 D(x): 0.8793 D(G(z)): 0.3686 / 0.0401
[2/5][1100/1583] Loss_D: 0.7728 Loss_G: 2.1345 D(x): 0.5886 D(G(z)): 0.1234 / 0.1626
[2/5][1150/1583] Loss_D: 0.9383 Loss_G: 3.7146 D(x): 0.8942 D(G(z)): 0.5075 / 0.0355
[2/5][1200/1583] Loss_D: 0.4951 Loss_G: 2.8725 D(x): 0.8084 D(G(z)): 0.2163 / 0.0764
[2/5][1250/1583] Loss_D: 0.6952 Loss_G: 2.1559 D(x): 0.6769 D(G(z)): 0.2063 / 0.1561
[2/5][1300/1583] Loss_D: 0.4560 Loss_G: 2.6873 D(x): 0.7993 D(G(z)): 0.1710 / 0.0908
[2/5][1350/1583] Loss_D: 0.9185 Loss_G: 3.9262 D(x): 0.8631 D(G(z)): 0.4938 / 0.0276
[2/5][1400/1583] Loss_D: 0.5935 Loss_G: 1.2768 D(x): 0.6625 D(G(z)): 0.1064 / 0.3214
[2/5][1450/1583] Loss_D: 0.8836 Loss_G: 4.0820 D(x): 0.9368 D(G(z)): 0.5101 / 0.0251
[2/5][1500/1583] Loss_D: 0.5268 Loss_G: 2.1486 D(x): 0.7462 D(G(z)): 0.1701 / 0.1450
[2/5][1550/1583] Loss_D: 0.5581 Loss_G: 3.0543 D(x): 0.8082 D(G(z)): 0.2489 / 0.0644
[3/5][0/1583] Loss_D: 0.6875 Loss_G: 2.3447 D(x): 0.7796 D(G(z)): 0.3180 / 0.1182
[3/5][50/1583] Loss_D: 0.7772 Loss_G: 1.2497 D(x): 0.5569 D(G(z)): 0.0763 / 0.3372
[3/5][100/1583] Loss_D: 1.8087 Loss_G: 0.8440 D(x): 0.2190 D(G(z)): 0.0213 / 0.4701
[3/5][150/1583] Loss_D: 0.6292 Loss_G: 2.8794 D(x): 0.8807 D(G(z)): 0.3623 / 0.0741
[3/5][200/1583] Loss_D: 0.5880 Loss_G: 2.2299 D(x): 0.8279 D(G(z)): 0.3026 / 0.1316
[3/5][250/1583] Loss_D: 0.7737 Loss_G: 1.2797 D(x): 0.5589 D(G(z)): 0.0836 / 0.3363
[3/5][300/1583] Loss_D: 0.5120 Loss_G: 1.5623 D(x): 0.7216 D(G(z)): 0.1406 / 0.2430
[3/5][350/1583] Loss_D: 0.5651 Loss_G: 3.2310 D(x): 0.8586 D(G(z)): 0.3048 / 0.0518
[3/5][400/1583] Loss_D: 1.3554 Loss_G: 5.0320 D(x): 0.9375 D(G(z)): 0.6663 / 0.0112
[3/5][450/1583] Loss_D: 0.5939 Loss_G: 1.9385 D(x): 0.6931 D(G(z)): 0.1538 / 0.1785
[3/5][500/1583] Loss_D: 1.5698 Loss_G: 5.0469 D(x): 0.9289 D(G(z)): 0.7124 / 0.0106
[3/5][550/1583] Loss_D: 0.5496 Loss_G: 1.7024 D(x): 0.6891 D(G(z)): 0.1171 / 0.2172
[3/5][600/1583] Loss_D: 2.0152 Loss_G: 6.4814 D(x): 0.9824 D(G(z)): 0.8069 / 0.0031
[3/5][650/1583] Loss_D: 0.6249 Loss_G: 2.9602 D(x): 0.8547 D(G(z)): 0.3216 / 0.0707
[3/5][700/1583] Loss_D: 0.4448 Loss_G: 2.3997 D(x): 0.8289 D(G(z)): 0.2034 / 0.1153
[3/5][750/1583] Loss_D: 0.5768 Loss_G: 2.5956 D(x): 0.8094 D(G(z)): 0.2721 / 0.1032
[3/5][800/1583] Loss_D: 0.5314 Loss_G: 2.9121 D(x): 0.8603 D(G(z)): 0.2838 / 0.0724
[3/5][850/1583] Loss_D: 0.9673 Loss_G: 4.2585 D(x): 0.9067 D(G(z)): 0.5233 / 0.0206
[3/5][900/1583] Loss_D: 0.7076 Loss_G: 2.7892 D(x): 0.7294 D(G(z)): 0.2625 / 0.0909
[3/5][950/1583] Loss_D: 0.4336 Loss_G: 2.8206 D(x): 0.8736 D(G(z)): 0.2363 / 0.0770
[3/5][1000/1583] Loss_D: 0.6914 Loss_G: 1.9334 D(x): 0.6811 D(G(z)): 0.2143 / 0.1734
[3/5][1050/1583] Loss_D: 0.6618 Loss_G: 1.8457 D(x): 0.6486 D(G(z)): 0.1421 / 0.2036
[3/5][1100/1583] Loss_D: 0.6517 Loss_G: 3.2499 D(x): 0.8540 D(G(z)): 0.3491 / 0.0532
[3/5][1150/1583] Loss_D: 0.6688 Loss_G: 3.9172 D(x): 0.9389 D(G(z)): 0.4170 / 0.0269
[3/5][1200/1583] Loss_D: 0.9467 Loss_G: 0.8899 D(x): 0.4853 D(G(z)): 0.1028 / 0.4567
[3/5][1250/1583] Loss_D: 0.6048 Loss_G: 3.3952 D(x): 0.8353 D(G(z)): 0.3150 / 0.0425
[3/5][1300/1583] Loss_D: 0.4915 Loss_G: 2.5383 D(x): 0.7663 D(G(z)): 0.1622 / 0.1071
[3/5][1350/1583] Loss_D: 0.7804 Loss_G: 1.5018 D(x): 0.5405 D(G(z)): 0.0719 / 0.2701
[3/5][1400/1583] Loss_D: 0.6432 Loss_G: 1.5893 D(x): 0.6069 D(G(z)): 0.0576 / 0.2577
[3/5][1450/1583] Loss_D: 0.7720 Loss_G: 3.8510 D(x): 0.9291 D(G(z)): 0.4558 / 0.0299
[3/5][1500/1583] Loss_D: 0.9340 Loss_G: 4.6210 D(x): 0.9556 D(G(z)): 0.5341 / 0.0141
[3/5][1550/1583] Loss_D: 0.7278 Loss_G: 4.0992 D(x): 0.9071 D(G(z)): 0.4276 / 0.0231
[4/5][0/1583] Loss_D: 0.4672 Loss_G: 1.9660 D(x): 0.7085 D(G(z)): 0.0815 / 0.1749
[4/5][50/1583] Loss_D: 0.5710 Loss_G: 2.3229 D(x): 0.6559 D(G(z)): 0.0654 / 0.1285
[4/5][100/1583] Loss_D: 0.8091 Loss_G: 0.8053 D(x): 0.5301 D(G(z)): 0.0609 / 0.4987
[4/5][150/1583] Loss_D: 0.5661 Loss_G: 1.4238 D(x): 0.6836 D(G(z)): 0.1228 / 0.2842
[4/5][200/1583] Loss_D: 0.6187 Loss_G: 1.6628 D(x): 0.6178 D(G(z)): 0.0744 / 0.2292
[4/5][250/1583] Loss_D: 0.9808 Loss_G: 2.0649 D(x): 0.5769 D(G(z)): 0.2623 / 0.1706
[4/5][300/1583] Loss_D: 0.6530 Loss_G: 2.7874 D(x): 0.8024 D(G(z)): 0.3063 / 0.0804
[4/5][350/1583] Loss_D: 0.5535 Loss_G: 2.5154 D(x): 0.7744 D(G(z)): 0.2165 / 0.1023
[4/5][400/1583] Loss_D: 0.5277 Loss_G: 2.1542 D(x): 0.6766 D(G(z)): 0.0801 / 0.1474
[4/5][450/1583] Loss_D: 0.5995 Loss_G: 2.6477 D(x): 0.7890 D(G(z)): 0.2694 / 0.0902
[4/5][500/1583] Loss_D: 0.7183 Loss_G: 1.2993 D(x): 0.5748 D(G(z)): 0.1000 / 0.3213
[4/5][550/1583] Loss_D: 0.4708 Loss_G: 2.0671 D(x): 0.7286 D(G(z)): 0.1094 / 0.1526
[4/5][600/1583] Loss_D: 0.5865 Loss_G: 1.9083 D(x): 0.7084 D(G(z)): 0.1745 / 0.1867
[4/5][650/1583] Loss_D: 1.5298 Loss_G: 4.2918 D(x): 0.9623 D(G(z)): 0.7240 / 0.0197
[4/5][700/1583] Loss_D: 0.9155 Loss_G: 0.9452 D(x): 0.4729 D(G(z)): 0.0575 / 0.4395
[4/5][750/1583] Loss_D: 0.7500 Loss_G: 1.7498 D(x): 0.5582 D(G(z)): 0.0772 / 0.2095
[4/5][800/1583] Loss_D: 0.5993 Loss_G: 2.5779 D(x): 0.7108 D(G(z)): 0.1829 / 0.1063
[4/5][850/1583] Loss_D: 0.6787 Loss_G: 3.6855 D(x): 0.9201 D(G(z)): 0.4084 / 0.0347
[4/5][900/1583] Loss_D: 1.2792 Loss_G: 2.2909 D(x): 0.6365 D(G(z)): 0.4471 / 0.1575
[4/5][950/1583] Loss_D: 0.6995 Loss_G: 3.3548 D(x): 0.9201 D(G(z)): 0.4188 / 0.0488
[4/5][1000/1583] Loss_D: 0.6913 Loss_G: 3.9969 D(x): 0.8630 D(G(z)): 0.3771 / 0.0242
[4/5][1050/1583] Loss_D: 0.7620 Loss_G: 1.7744 D(x): 0.6668 D(G(z)): 0.2290 / 0.2204
[4/5][1100/1583] Loss_D: 0.6901 Loss_G: 3.1660 D(x): 0.8472 D(G(z)): 0.3595 / 0.0593
[4/5][1150/1583] Loss_D: 0.5866 Loss_G: 2.4580 D(x): 0.7962 D(G(z)): 0.2695 / 0.1049
[4/5][1200/1583] Loss_D: 0.8830 Loss_G: 3.9824 D(x): 0.9264 D(G(z)): 0.5007 / 0.0264
[4/5][1250/1583] Loss_D: 0.4750 Loss_G: 2.1389 D(x): 0.8004 D(G(z)): 0.1933 / 0.1464
[4/5][1300/1583] Loss_D: 0.4972 Loss_G: 2.3561 D(x): 0.8266 D(G(z)): 0.2325 / 0.1285
[4/5][1350/1583] Loss_D: 0.6721 Loss_G: 1.1904 D(x): 0.6042 D(G(z)): 0.0839 / 0.3486
[4/5][1400/1583] Loss_D: 0.4447 Loss_G: 2.7106 D(x): 0.8540 D(G(z)): 0.2219 / 0.0852
[4/5][1450/1583] Loss_D: 0.4864 Loss_G: 2.5237 D(x): 0.7153 D(G(z)): 0.1017 / 0.1036
[4/5][1500/1583] Loss_D: 0.7662 Loss_G: 1.1344 D(x): 0.5429 D(G(z)): 0.0600 / 0.3805
[4/5][1550/1583] Loss_D: 0.4294 Loss_G: 2.9664 D(x): 0.8335 D(G(z)): 0.1943 / 0.0689
3.3.6 Results
Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D's and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch for every epoch. Third, we will look at a batch of real data next to a batch of fake data from G.
Loss and training iterations
The following is a plot of D's and G's losses versus training iterations.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
Visualization of G's progression
Remember how we saved the generator's output on the fixed_noise batch after every epoch of training? Now we can visualize G's training progression with an animation. Press the play button to start the animation.
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
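Note that HTML(ani.to_jshtml()) only renders inside a notebook. As an alternative (a sketch, not part of the original tutorial; it assumes ffmpeg is installed, and the output file name is arbitrary), matplotlib's animation API can write the same animation to disk:

# Save the animation to an mp4 instead of displaying it inline (requires ffmpeg)
ani.save("dcgan_progress.mp4", writer="ffmpeg", fps=1)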
Real image vs fake image
Finally, let's take a look at some real images and fake images side by side.
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1,2,0)))
plt.show()
4. Further work
We have reached the end of our tutorial, but there are several directions you can explore further. You can:
Train for longer to see how good the results get (a checkpointing sketch follows this list)
Modify this model to take a different dataset, or change the size of the images and the model architecture
Check out some other cool GAN projects here
Create a GAN that generates music
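If you plan to train for longer (the first suggestion above), it helps to checkpoint the models so training can be resumed later. Here is a minimal sketch using PyTorch's standard state_dict serialization (the file names are hypothetical, not part of the original tutorial):

# Save the generator and discriminator weights
torch.save(netG.state_dict(), "netG.pth")
torch.save(netD.state_dict(), "netD.pth")

# Later: rebuild the models and load the weights back
netG = Generator(ngpu).to(device)
netG.load_state_dict(torch.load("netG.pth", map_location=device))
netD = Discriminator(ngpu).to(device)
netD.load_state_dict(torch.load("netD.pth", map_location=device))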
Article source: https://zhiya360.com/docs/pytorchstudy/pytorch-dcgan/01-duikang