Pytorch从0到1之变分自编码器——(10)

开篇

这次我们来说一说变分自编码器。变分编码器也是一种很常见的网络结构。它的作用和GAN有些类似,都是为我们生成一张可以"以假乱真"的图片。但是VAE与GAN不同的是,它不用区分生成器和区分器,他在一个网络中完成整个过程。
我们首先输入图片,对他进编码,然后通过我们的网络结构生成编码的方差与均值,然后再解码生成图片,这里最重要的是这个方差和均值的生成。自己刚刚复现了一遍,感觉这里还是挺多需要了解和掌握的地方的,并且也是一种不错的设计思路。
具体的详细信息大家可以参考VAE介绍。这里我们这只说代码。

VAE变分自编码器

库的引入

import os
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
import torch
from torchvision.utils import save_image

配置设备以及设置保存图片的地址

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sample_dir = 'sample_dir'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

超参数的定义

image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

这里说明一下,h_dim指的是第一个隐藏层,也就是输入图片之后第一个经过的隐藏层的输出特征尺寸。z_dim表示的是预测方差和均值的网络层的输出特征尺寸,大家可以理解成是我们的均值和方差的尺寸大小。
数据准备和加载

data = torchvision.datasets.MNIST(root = '../../data/',
                                  download = True,
                                  train = True,
                                  transform = transforms.ToTensor())

data_loader = torch.utils.data.DataLoader(dataset = data,
                                          shuffle = True,
                                          batch_size = batch_size)

VAE模型的构建

class VAE(nn.Module):
    def __init__(self,image_size,h_dim = 400,z_dim = 20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size,h_dim)
        self.fc2 = nn.Linear(h_dim,z_dim)
        self.fc3 = nn.Linear(h_dim,z_dim)
        self.fc4 = nn.Linear(z_dim,h_dim)
        self.fc5 = nn.Linear(h_dim,image_size)
    def encode(self,x):
        h = F.relu(self.fc1(x))
        return self.fc2(h),self.fc3(h)
    def reparameterize(self,mu,log_var):
        std = torch.exp(log_var / 2)
        ep = torch.randn_like(std)
        return mu + ep * std

    def decode(self,z):
        h = F.relu(self.fc4(z))
        return F.relu(self.fc5(h))
    def forward(self, x):
        mu,log_var = self.encode(x)
        z = self.reparameterize(mu,log_var)
        x_reconst = self.decode(z)
        return x_reconst,mu,log_var

encode操作就是在计算均值和方差,然后将其传输给reparameterize函数,将方差和均值做一步处理,得到了一个由方差和均值组成的式子,然后将其传给decode函数即可解码得到图片。
定义模型和优化器

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)

训练模型

for epoch in range(num_epochs):
    for i,(images,_) in enumerate(data_loader):
        x = images.to(device).view(-1,image_size)
        x_reconst,mu,log_var = model(x)
        loss_reconst = F.binary_cross_entropy(x_reconst,x,size_average=True)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = x_reconst + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader), loss_reconst.item(), kl_div.item()))

我们的损失不仅仅是由我们重新生成的图片和原有真实图片的二分交叉熵损失构成,还由KL散度构成,这个公式大家可以参照上面分享的链接,内部有详细的讲解。
二者加和即得到了我们的损失,然后将损失反向传播并进行优化,即完成了训练。
测试模型

with torch.no_grad():
    # Save the sampled images
    # 生成一组正态分布的随机数
    z = torch.randn(batch_size, z_dim).to(device)
    out = model.decode(z).view(-1, 1, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

    # Save the reconstructed
    out, _, _ = model(x)
    x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
    save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

随机生成一组正态分布的数,均值为0,方差为1.然后将这个编码喂入我们的模型,将生成的图片存储下来,再将真实的图片喂入模型,将二者连接起来方便对比。再将真实的存入。

总结

VAE是一个和GAN功能很相似的网络结构,我们可以借助对GAN的理解好好理解一下VAE,主要是要弄清损失函数的由来和计算方差和均值的意义。

原创文章 101 获赞 13 访问量 2311

猜你喜欢

转载自blog.csdn.net/weixin_44755413/article/details/105911392