PyTorch (5): Generative Adversarial Networks on MNIST

A GAN addresses a well-known problem in unsupervised learning: given a set of samples, train a system that can generate new samples similar to them.

A generative adversarial network mainly consists of the following two sub-networks:

  • Generator: takes random noise as input and generates an image
  • Discriminator: determines whether an input image is real or fake

 

 Alternating training:

  • Training the discriminator requires both real images and fake images produced by the generator. The discriminator tries to classify real images as real and generated images as fake. (The discriminator wants to tell real from fake as well as possible.)
  • Training the generator only requires images produced by the generator. The generated images are fed into the discriminator, and the generator is updated so that the discriminator classifies them as real. (The generator wants its images to look as real as possible.)

After training for a while, the discriminator and the generator reach an equilibrium: the generated images are realistic enough to fool the discriminator.

For the generator, the network structure is similar to the following; of course, the number of channels, stride, kernel size, padding and so on can be adjusted for the specific task.

The input of the network above is 100-dimensional noise and the output is a 3x64x64 image. The input can be viewed as a 100x1x1 feature map that is gradually upsampled to 4x4, 8x8, 16x16, 32x32 and 64x64 by transposed convolution (deconvolution). Intuitively, the image information is stored in the 100-dimensional vector: based on it, the first transposed-convolution layers outline basic information such as contours and colour, and the later layers gradually refine the details. The deeper the network, the finer the details.

The output feature-map size of a transposed convolution is $H_{out} = (H_{in} - 1) \times S + K - 2P$ (where $S$ is the stride, $K$ is the kernel size and $P$ is the padding).

 

Steps:

  • Define the model
  • Load the data
  • Configure the parameters
  • Train the model

 

Method 1: fully connected neural network

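Model definition: save the file as mnist_model.py in the models folder (note that main.py below imports from models.DCGAN_mnist_model, so either adjust that import or name the file accordingly)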

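# Discriminator
# Flatten the 28*28 image to 784 values, pass it through a multi-layer perceptron, and apply a final sigmoid to get a probability between 0 and 1 for binary classification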
from torch import nn
class NetD(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    A three-layer perceptron maps a 784-dimensional input to a single
    sigmoid-activated score between 0 and 1, used for real/fake binary
    classification.
    """

    def __init__(self):
        super().__init__()

        layers = [
            # 784 input features -> 256 hidden units
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # hidden-to-hidden linear mapping
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2, inplace=True),
            # single output squashed to a probability by the sigmoid
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per input row (output has one column)."""
        return self.main(x)

# Generator: the input is random 100-dimensional noise drawn from a Gaussian with mean 0 and variance 1
class NetG(nn.Module):
    """Fully connected generator: 100-dimensional noise -> 784 values.

    The final Tanh keeps every generated value inside [-1, 1].
    """

    def __init__(self):
        super().__init__()

        layers = [
            # linear projection of the 100-d noise to 256 hidden units
            nn.Linear(100, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            # 784 outputs, bounded to [-1, 1] by Tanh
            nn.Linear(256, 784),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of noise vectors to a batch of 784-value outputs."""
        return self.main(x)

Data loading:

from torch.utils import data
from torchvision import datasets
from torchvision import transforms as T

# Convert each image to a tensor, then normalise with mean 0.5 / std 0.5,
# i.e. (x - 0.5) / 0.5.
img_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split served in shuffled batches of 256; the final
# incomplete batch is dropped.
mnist = datasets.MNIST(root='.data/', train=True, transform=img_transform)
mnist_loader = data.DataLoader(
    mnist,
    batch_size=256,
    shuffle=True,
    drop_last=True,
)

Parameters: save the file as config.py in the config folder

class Config:
    """Hyper-parameters and file paths used by the training/generation code."""

    # --- data ---
    data_path = 'data/'   # where the dataset is stored
    num_workers = 4       # number of worker processes for data loading
    image_size = 96       # image size
    batch_size = 256      # mini-batch size

    # --- optimisation / model ---
    max_epoch = 1000
    lr1 = 4e-4            # generator learning rate
    lr2 = 4e-4            # discriminator learning rate
    nz = 100              # noise dimensionality
    ngf = 64              # number of generator feature maps
    ndf = 64              # number of discriminator feature maps

    save_path = 'imgs/'   # where generated images are saved

    # --- schedule ---
    d_every = 1           # intended: train the discriminator every batch
    g_every = 5           # intended: train the generator every 5 batches
    decay_every = 10      # save the model every 10 epochs

    # --- pre-trained model paths ---
    netd_path = None
    netg_path = None

    # --- generation (test-time) parameters ---
    gen_img = 'result.png'
    # keep the best 24 out of 128 generated images
    gen_num = 24
    gen_search_num = 128
    gen_mean = 0          # mean of the noise
    gen_std = 1           # standard deviation of the noise

Model training: save the file as main.py

#导入相关包
import fire
import torch as t
from torch.autograd import Variable as V
from torch.utils import data
from torchvision import datasets
from torchvision import transforms as T
from models.DCGAN_mnist_model import NetG,NetD
from config.config import Config

#数据加载
# Convert each image to a tensor, then normalise with mean 0.5 / std 0.5,
# i.e. (x - 0.5) / 0.5.
img_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])
# MNIST training split served in shuffled batches of 256; the final
# incomplete batch is dropped.
mnist = datasets.MNIST(root='.data/', train=True, transform=img_transform)
mnist_loader = data.DataLoader(
    mnist,
    batch_size=256,
    shuffle=True,
    drop_last=True,
)


# Model training
# Shared configuration object; train()/generate() override its attributes
# from their keyword arguments.
opt = Config()
def train(**kwargs):
    """Train the fully connected GAN on MNIST and checkpoint periodically.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` configuration (e.g. ``max_epoch=50``).

    For each batch, one discriminator update is performed (real images
    labelled 1, detached generator output labelled 0), followed by one
    generator update (fresh fakes labelled 1).  Whenever
    ``epoch % opt.decay_every == 0`` both state dicts are saved under
    ``checkpoints/``.

    NOTE: ``opt.d_every`` and ``opt.g_every`` are not consulted here; both
    networks are updated on every batch.
    """
    import os  # local import: only needed to create the checkpoint folder

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # step 1: models, optionally resumed from saved weights
    netg = NetG()
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path))
    netd = NetD()
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path))

    # step 2: optimisers and loss
    optimizer_g = t.optim.Adam(netg.parameters(), lr=opt.lr1)
    optimizer_d = t.optim.Adam(netd.parameters(), lr=opt.lr2)

    # BCELoss: binary cross-entropy, used for the real/fake classification
    criterion = t.nn.BCELoss()

    # t.save does not create missing directories
    os.makedirs('checkpoints', exist_ok=True)

    for epoch in range(opt.max_epoch):
        for datas, _ in mnist_loader:
            num_imgs = len(datas)
            # flatten every image to a vector for the linear layers
            real_img = datas.view(num_imgs, -1)

            # ---- train the discriminator ----
            # push real images towards label 1
            output = netd(real_img)
            # Labels are built to match the discriminator output exactly
            # (shape and dtype): BCELoss requires input and target of the
            # same size, and this also no longer depends on
            # opt.batch_size matching the loader's batch size.
            true_labels = t.ones_like(output)
            fake_labels = t.zeros_like(output)
            error_d_real = criterion(output, true_labels)

            # push fake images towards label 0
            noises = t.randn(num_imgs, opt.nz)
            # detach() cuts the graph so that no gradient flows back into
            # the generator while the discriminator is being trained
            fake_img = netg(noises).detach()
            fake_out = netd(fake_img)
            error_d_fake = criterion(fake_out, fake_labels)

            d_loss = error_d_real + error_d_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # ---- train the generator ----
            noises = t.randn(num_imgs, opt.nz)
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            # the generator wants its fakes to be classified as real (1)
            error_g = criterion(fake_output, true_labels)
            optimizer_g.zero_grad()
            error_g.backward()
            optimizer_g.step()

        # save the models
        if epoch % opt.decay_every == 0:
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)

# Load the trained models and generate images from random noise
def generate(**kwargs):
    """Sample images from a trained generator and save the best-scored ones.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` configuration; ``netd_path`` and ``netg_path`` must
    point to saved state dicts.

    ``opt.gen_search_num`` noise vectors are drawn, the generated images are
    scored by the discriminator, and the ``opt.gen_num`` highest-scoring
    ones are written to ``<opt.save_path>/<index>.png``.

    Returns:
        The tensor of all generated images, rescaled from [-1, 1] to [0, 1].

    Raises:
        ValueError: if ``netd_path`` or ``netg_path`` is not set.
    """
    import os
    from torchvision.utils import save_image

    for k, v in kwargs.items():
        setattr(opt, k, v)

    if not (opt.netd_path and opt.netg_path):
        raise ValueError('generate() requires both netd_path and netg_path')

    netg, netd = NetG().eval(), NetD().eval()
    # load the trained weights
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))

    # With the defaults (mean 0, std 1) this is exactly t.randn(...)
    noises = t.randn(opt.gen_search_num, opt.nz) * opt.gen_std + opt.gen_mean

    # Generate the images and score them; no autograd graph is needed here,
    # so the forward passes themselves run under no_grad.
    with t.no_grad():
        fake_img = netg(noises)
        scores = netd(fake_img).view(-1)

    # topk(...)[0] holds the k largest scores, [1] their indices
    indexs = scores.topk(opt.gen_num)[1].tolist()

    # undo the (x - 0.5) / 0.5 normalisation and clip to the valid range
    fake_img = (fake_img * 0.5 + 0.5).clamp(0, 1)

    os.makedirs(opt.save_path, exist_ok=True)
    for i in indexs:
        # The path is passed positionally: the parameter was renamed from
        # `filename` to `fp` in torchvision, so the keyword form breaks.
        save_image(fake_img[i].view(28, 28),
                   os.path.join(opt.save_path, '%d.png' % i))
    return fake_img

if __name__ == '__main__':
    # Hand command-line dispatch over to python-fire.
    fire.Fire()

Run `python main.py train` to start training; training can be stopped manually after a certain number of epochs. After training, run `python main.py generate --netd_path='checkpoints/netd_50.pth' --netg_path='checkpoints/netg_50.pth'` to generate images.

The MNIST digits generated by the GAN after 50 epochs look as follows:

50epoch
50epoch

Change netd_path and netg_path to view the MNIST digits generated by the GAN after 100, 150 and 200 epochs:

100epoch

 

150epoch

 

200epoch

 

 

Method 2: deep convolutional neural network

Model definition:

from torch import nn
# Generator network
class NetG(nn.Module):
    """Convolutional generator built from transposed convolutions.

    A 100-channel 1x1 input is upsampled 1x1 -> 3x3 -> 7x7 -> 15x15 -> 27x27
    and emitted as a single-channel image whose values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        stages = []
        # (in_channels, out_channels, stride): outputs 3x3, 7x7, 15x15
        for c_in, c_out, stride in ((100, 128, 1), (128, 64, 2), (64, 32, 2)):
            stages += [
                nn.ConvTranspose2d(c_in, c_out, 3, stride, 0, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        stages += [
            # single-channel 27x27 output
            nn.ConvTranspose2d(32, 1, 3, 2, 2, bias=False),
            # hyperbolic tangent maps the output into [-1, 1]
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Generate a batch of images from a batch of noise tensors."""
        return self.main(x)

# Discriminator network
class NetD(nn.Module):
    """Convolutional discriminator for single-channel 27x27 images.

    Strided convolutions shrink the input 27x27 -> 15x15 -> 7x7 -> 3x3 -> 1x1;
    the final sigmoid yields a score between 0 and 1.
    """

    def __init__(self):
        super().__init__()

        stages = []
        # (in_channels, out_channels, padding): outputs 15x15, 7x7, 3x3
        for c_in, c_out, pad in ((1, 32, 2), (32, 64, 0), (64, 128, 0)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, 2, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        stages += [
            # 1x1 output
            nn.Conv2d(128, 1, 3, 1, 0, bias=False),
            # map to the range 0..1
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Score a batch of images; the result keeps four dimensions."""
        return self.main(x)

Data loading:

# Data pre-processing for the deep convolutional network
# Centre-crop each image to 27x27, convert it to a tensor, then normalise
# with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
img_transform2 = T.Compose([
    T.CenterCrop(27),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])
# MNIST training split served in shuffled batches of 256 by 4 worker
# processes; the final incomplete batch is dropped.
mnist = datasets.MNIST(root='.data/', train=True, transform=img_transform2)
mnist_loader = data.DataLoader(
    mnist,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

Model training:

#导入相关包
import fire
import torch as t
from torch.utils import data
from torch.autograd import Variable as V
from torchvision import datasets
from torchvision import transforms as T
from models.DCGAN_mnist_model import NetG,NetD
from config.config import Config


#数据加载
# Shared configuration object; train()/generate() override its attributes
# from their keyword arguments.
opt = Config()

# Centre-crop each image to 27x27, convert it to a tensor, then normalise
# with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
img_transform = T.Compose([
    T.CenterCrop(27),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])
# MNIST training split served in shuffled batches of 256; the final
# incomplete batch is dropped.
mnist = datasets.MNIST(root='.data/', train=True, transform=img_transform)
mnist_loader = data.DataLoader(
    mnist,
    batch_size=256,
    shuffle=True,
    num_workers=opt.num_workers,
    drop_last=True,
)


#模型训练
def train(**kwargs):
    """Train the convolutional GAN on MNIST and checkpoint periodically.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` configuration (e.g. ``max_epoch=20``).

    For each batch, one discriminator update is performed (real images
    labelled 1, detached generator output labelled 0), followed by one
    generator update that reuses the same noise (fakes labelled 1).
    Whenever ``epoch % opt.decay_every == 0`` the epoch number is printed
    and both state dicts are saved under ``checkpoints2/``.

    NOTE: ``opt.d_every`` and ``opt.g_every`` are not consulted here; both
    networks are updated on every batch.
    """
    import os  # local import: only needed to create the checkpoint folder

    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)

    # step 1: models, optionally resumed from saved weights.
    # The configured paths are loaded (the previous t.load(None) could
    # never succeed).
    netg = NetG()
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path))
    netd = NetD()
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path))

    # step 2: optimisers and loss
    optimizer_g = t.optim.Adam(netg.parameters(), lr=opt.lr1)
    optimizer_d = t.optim.Adam(netd.parameters(), lr=opt.lr2)

    # BCELoss: binary cross-entropy, used for the real/fake classification
    criterion = t.nn.BCELoss()

    # t.save does not create missing directories
    os.makedirs('checkpoints2', exist_ok=True)

    for epoch in range(opt.max_epoch):
        for datas, _ in mnist_loader:
            num_imgs = len(datas)
            # with fully connected models use datas.view(num_imgs, -1) instead
            real_img = datas

            # ---- train the discriminator ----
            # push real images towards label 1
            output = netd(real_img)
            # Labels are built to match the discriminator output exactly
            # (shape and dtype): BCELoss requires input and target of the
            # same size, and this also no longer depends on
            # opt.batch_size matching the loader's batch size.
            true_labels = t.ones_like(output)
            fake_labels = t.zeros_like(output)
            error_d_real = criterion(output, true_labels)

            # push fake images towards label 0
            # with fully connected models use t.randn(num_imgs, opt.nz) instead
            noises = t.randn(num_imgs, opt.nz, 1, 1)
            # detach() keeps discriminator gradients out of the generator
            fake_img = netg(noises).detach()
            fake_out = netd(fake_img)
            error_d_fake = criterion(fake_out, fake_labels)

            d_loss = error_d_real + error_d_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # ---- train the generator ----
            fake_img = netg(noises)
            fake_output = netd(fake_img)
            # the generator wants its fakes to be classified as real (1)
            error_g = criterion(fake_output, true_labels)
            optimizer_g.zero_grad()
            error_g.backward()
            optimizer_g.step()

        # save the models
        if epoch % opt.decay_every == 0:
            print('epoch:{}'.format(epoch))
            t.save(netd.state_dict(), 'checkpoints2/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints2/netg_%s.pth' % epoch)


# Load the trained models and generate images from random noise
def generate(**kwargs):
    """Sample images from a trained convolutional generator and save the best.

    Every keyword argument overrides the attribute of the same name on the
    module-level ``opt`` configuration; ``netd_path`` and ``netg_path`` must
    point to saved state dicts.

    ``opt.gen_search_num`` noise tensors are drawn, the generated images are
    scored by the discriminator, and the ``opt.gen_num`` highest-scoring
    ones are written to ``<opt.save_path>/<index>.png``.

    Raises:
        ValueError: if ``netd_path`` or ``netg_path`` is not set.
    """
    import os
    from torchvision.utils import save_image

    for k, v in kwargs.items():
        setattr(opt, k, v)

    if not (opt.netd_path and opt.netg_path):
        raise ValueError('generate() requires both netd_path and netg_path')

    netg, netd = NetG().eval(), NetD().eval()
    # load the trained weights
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))

    # The transposed-convolution generator needs 4-D input (N, nz, 1, 1),
    # the same shape train() feeds it.  With the default mean 0 / std 1 the
    # scaling leaves the t.randn sample unchanged.
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1) * opt.gen_std + opt.gen_mean

    # Generate the images and score them; no autograd graph is needed here,
    # so the forward passes themselves run under no_grad.
    with t.no_grad():
        fake_img = netg(noises)
        # discriminator output is (N, 1, 1, 1); flatten to one score per image
        scores = netd(fake_img).view(-1)

    # topk(...)[0] holds the k largest scores, [1] their indices
    indexs = scores.topk(opt.gen_num)[1].tolist()

    # undo the (x - 0.5) / 0.5 normalisation and clip to the valid range
    fake_img = (fake_img * 0.5 + 0.5).clamp(0, 1)

    os.makedirs(opt.save_path, exist_ok=True)
    for i in indexs:
        # Each image is saved at its native size (the generator emits 27x27,
        # so the former .view(28, 28) could not succeed).  The path is passed
        # positionally because torchvision renamed `filename` to `fp`.
        save_image(fake_img[i], os.path.join(opt.save_path, '%d.png' % i))

if __name__ == '__main__':
    # Hand command-line dispatch over to python-fire.
    fire.Fire()

The MNIST digits generated after 10 and 20 epochs are as follows:

10epoch

 

20epoch

 The deep convolutional network achieves good, almost noise-free results after 20 epochs, whereas the fully connected network still shows noise after 200 epochs. The deep convolutional network therefore gives better results.

Guess you like

Origin blog.csdn.net/qq_24946843/article/details/89818062