理解VAE(变分自编码器)[结合代码]

1.贝叶斯公式

贝叶斯理论的思路是,在主观判断的基础上,先估计一个值(先验概率),然后根据观察的新信息不断修正(可能性函数)。

P(A):没有数据B的支持下,A发生的概率,也叫做先验概率。这完全是根据经验做出的判断,这也是前面说的贝叶斯公式的主观因素部分。

P(A|B):在数据B的支持下,A发生的概率,也叫后验概率。即在B事件发生之后,我们对A事件概率的重新评估。

P(B|A):给定某参数A的概率分布:也叫似然函数。这是一个调整因子,即新信息B带来的调整,作用是使得先验概率更接近真实概率。至于新信息带来的调整作用大不大,还得看因子的值大不大。

假如我在大学校园中随机找出一个人B,身高170,体重60。P(A)是大学中男生的占比,是先验概率。

那么从所有男生中选出身高170,体重60的人的概率就是似然概率,也就是从男生的分布(男生的概率密度函数)中得到B的概率。

那么选出的这个人B是男生的可能性就是后验概率P(A|B)。

无监督学习中的一个核心问题就是——密度估计问题。要训练出一个模型,使该模型的概率密度函数和真实的训练数据分布尽可能相似。

2.生成模型分类

无监督学习两种典型思路:

1.显式密度估计:显式的定义并求解分布Pmodel(x)。分布的方程是能够写出来定义出来的。可以算出特征空间中选取的点生成的样本可信度,概率值。

2.隐式的密度估计:学习一个模型Pmodel(x),无需显式的公式定义。只能够生成样本。只会产生特征空间中概率比较大,与真图相似的点。

生成模型主要分为显示概率密度-可求解显式概率密度-可近似隐式概率密度

2.1 PixelRNN与PixelCNN

PixelRNN与PixelCNN是前面的像素生成后,来预测下一个像素的值,在知道前面的像素值后,该像素值的概率分布(0-255的每个值的概率)是可以计算出来的。

2.2 VAE

编码器可以单独作为一个特征提取网络来进行分类任务。

解码器可以单独分割出来作为一个图像生成器。

2.2.1 香农熵、交叉熵、KL散度

信息量:一个不太可能发生的事件居然发生了,我们收到的信息要多于一个非常可能发生的事件发生。

用一个例子来理解一下,假设我们收到了以下两条消息:

A:今天早上太阳升起

B:今天早上有日食

我们认为消息A的信息量是如此之少,甚至于没有必要发送,而消息B的信息量就很丰富。利用这个例子,我们来细化一下信息量的基本想法:①非常可能发生的事件信息量要比较少,在极端情况下,确保能够发生的事件应该没有信息量;②不太可能发生的事件要具有更高的信息量。事件包含的信息量应与其发生的概率负相关。

事件x的信息量为:

该信息量公式只能处理随机变量的取指定值时的信息量。要对整个概率分布的平均信息量进行描述时用香农熵。

香农熵:具体方法为求上述信息量函数关于概率分布P的期望,这个期望值(即熵)为:

当概率分布连续时,求和号变积分号。

那些接近确定性的分布(输出几乎可以确定)具有较低的熵,那些接近均匀分布的概率分布具有较高的熵。比如硬币两面的概率都是0.5,那么他的熵就高。如果硬币两面概率分别是0.2和0.8,那么他的熵就低。

KL散度:假设随机变量的真实概率分布为P(X),而我们在处理实际问题时使用了一个近似的分布Q(X)来进行建模。由于我们使用的是Q(X)而不是真实的P(X),所以我们在具体化的取值时需要一些附加的信息来抵消分布不同造成的影响。我们需要的平均附加信息量可以使用相对熵,或者叫KL散度(Kullback-Leibler Divergence)来计算,KL散度可以用来衡量两个分布的差异:

交叉熵:交叉熵与上面介绍的KL散度关系很密切,让我们把上面的KL散度公式换一种写法:

交叉熵H(P,Q)就等于: 

如果把P看作随机变量的真实分布的话,KL散度左半部分的−H(P(X))其实是一个固定值,KL散度的大小变化其实是由右半部分交叉熵来决定的,因为右半部分含有近似分布Q,我们可以把它看作网络或模型的实时输出,把KL散度或者交叉熵看做真实标签与网络预测结果的差异,所以神经网络的目的就是通过训练使近似分布Q逼近真实分布P。从理论上讲,优化KL散度与优化交叉熵的效果应该是一样的。所以我认为,在深度学习中选择优化交叉熵而非KL散度的原因可能是为了减少一些计算量,交叉熵毕竟比KL散度少一项。

2.2.2 变分推理

隐变量图模型:隐变量一般表示了一些被观测变量(也就是输入样本)的属性,这样隐变量和被观测变量就组成了隐变量图模型。

由于分母的联合概率密度函数不好求,也就是边缘概率P(x)不好求,所以导致整个后验概率P(z|x)不好求。

推理:一般我们会观测数据来获得对数据的见解或知识,这就是从被观测变量数据推理无法直接观测的隐变量的过程。

变分:既然真实的后验概率分布比较复杂,那么我们尝试使用一组比较简单并且可以参数化的概率分布去近似它。近似的过程就是变分。

变分推理:使用近似的概率分布去完成在给定被观测变量的情况下对隐变量概率分布估计的过程就叫做变分推理。

2.2.3 ELBO

我们想要使得近似分布q尽可能等于后验概率p(z|x),就要使他们之间的KL散度最小。因为logp(x)与隐变量z无关,所以可以移到期望外面,KL散度简化完的公式为

将logp(x)移到左边后为

由于p(x)是被观测变量,在隐空间是一个定值。KL散度又是大于等于0,且分布相同时才等于零。所以要想最小化q和p(z|x)的KL散度,就需要最大化期望 E[logp(z,x)-logq(z)],也就是最大化证据下界ELBO:

我们可以自己选一个符合一般概率分布的近似分布q,通过最大化ELBO(求导求极值)来使得q的曲线逼近真实后验概率分布曲线。

2.2.4 变分自编码器

VAE中的编码器就是后验概率p(z|x),可以通过q(z)来近似,q(z)也可以写作q(z|x)。

VAE中的解码器就是似然函数p(x|z)。

ELBO就可以改写为:

最大化ELBO就需要最大化第一项最小化第二项。

第一项是似然函数在近似分布q下的期望,表示重建的相似度,p(x|z)是一个高斯分布,z是编码器生成的均值,x是输入的样本,高斯分布中均值是概率密度最高的地方,所以最大化第一项相当于使神经网络预测的均值越靠近输入样本。

第二项是近似分布q,也就是编码器,与隐变量先验分布p(z)的KL散度,为了方便计算,我们给定p(z)为单位高斯分布,均值为0,方差是单位矩阵。这里最小化KL散度就是使编码器输出的隐变量分布和单位高斯分布更加接近。

上图是VAE的简要示意图,复杂分布的输入数据通过编码器映射到隐空间中,使得隐变量的概率密度靠近先验分布,也就是单位高斯分布。解码器将简单分布的隐变量通过解码器映射到复杂分布的图像概率空间中。

2.2.5 目标函数

首先是编码器部分的目标函数: 

首先看一下KL散度这一项,q是符合高斯分布的,均值和方差是由神经网络生成的。p(z)符合单位高斯分布。其中\varphi是参数化的神经网络。

两个高斯分布之间的KL散度是有解析解的。代入q和p(z)的高斯分布,化简如下。

然后是解码器部分的目标函数:

其中,p(x|z)是一个高斯分布,\theta是参数化的神经网络,神经网络输出该高斯分布的均值。p(x|z)的log值化简如下,x是样本标签,也就是原图,\mu _{\theta }是神经网络输出值,最后就是求他两的二范数。

但是该目标函数取决于隐变量z,z是从近似分布q中采样得到的,所以在对\theta求梯度反向传播时就不行。

2.2.6 重参数化

上图是VAE的整个流程图,输入数据x通过编码器得到近似分布q的均值和方差,然后与单位高斯分布p(z)计算KL散度。

然后对近似分布q进行采样得到隐变量z,将z送入解码器中得到重建的预测值,然后将预测值与原图计算二范数。但是由于采样是随机的,不能反向传播,所以我们需要重参数化

重参数化就是,任意高斯分布都可以写成均值+标准差*单位高斯分布的形式。所以任意高斯分布的采样都可以写成均值+标准差*单位高斯分布采样的形式。反向传播时,标准高斯分布的采样可以当作常数。

3.VAE代码

训练代码如下,

import os
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
import numpy as np
import torchvision
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EPOCH = 10
BATCH_SIZE = 64
path_model = "./VAEmodel.pth"
path_state_dict = "./VAEmodel_state_dict.pth"

im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Lambda(lambda x: x.repeat(3,1,1)),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 标准化
])

train_set = MNIST('./mnist', transform=im_tfs,download=True)
train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 100)
        self.fc21 = nn.Linear(100, 2) # mean
        self.fc22 = nn.Linear(100, 2) # var
        self.fc3 = nn.Linear(2, 100)
        self.fc4 = nn.Linear(100, 400)
        self.fc5 = nn.Linear(400, 784)

    def encode(self, x): #编码层
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc21(h2), self.fc22(h2)

    # def sampel(self, mu, logvar):
    #     std = logvar.mul(0.5).exp_()  # e**(x*0.5)
    #     eps = torch.FloatTensor(std.size()).normal_(0, 1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_() #e**(x*0.5)
        eps = torch.FloatTensor(std.size()).normal_(0,1)
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):#解码层
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        h5 = F.tanh(self.fc5(h4))
        return h5

    def forward(self, x):
        mu, logvar = self.encode(x) # 编码
        z = self.reparametrize(mu, logvar) # 重新参数化成正态分布
        return z, self.decode(z), mu, logvar # 解码,同时输出均值方差
net = VAE() # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()
#469
x, _ = train_set[2] # (3,28,28)
x = x.view(x.shape[0], -1)#(3,784)
if torch.cuda.is_available():
    x = x.cuda()
x = Variable(x)
_, _, mu, var = net(x)
print(_.size())

# -----------------------------------------------------
reconstruction_function = nn.MSELoss(size_average=False)

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

def to_img(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 3, 28, 28)
    return x

# 定义一个函数,将最后的结果转换回图片
def to_img2(x):
    '''
    定义一个函数将最后的结果转换回图片
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

def main():
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    X = []
    Y = []
    # --------------------------------
    loss2 = nn.BCELoss()
    for e in range(EPOCH):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        batch_count = 0
        for im, label in train_data:
            im = im.view(im.shape[0],3,-1) # (batch, 3, 768)
            im = Variable(im)
            if torch.cuda.is_available():
                im = im.cuda()
            _, recon_im, mu, logvar = net(im)
            loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 将 loss 平均


            # print(im.size(),recon_im.size(),label.size())
            # loss3 = loss2(recon_im, label)
            # train_l_sum += loss3.item()
            batch_count += 1

            # 计算损失值
            loss1 = loss.detach().numpy()
            # x.append(epoch)
            Y.append(loss1)

            # 梯度归零
            optimizer.zero_grad()
            # 计算损失函数的梯度
            loss.backward()
            # 更新优化器
            optimizer.step()
            torch.save(net, path_model)
            net_state_dict = net.state_dict()
            torch.save(net_state_dict, path_state_dict)

            if (e + 1) % 2 == 0:
                save1 = to_img(im.cpu().data)
                save = to_img(recon_im.cpu().data)
                if not os.path.exists('./vae_img'):
                    os.mkdir('./vae_img')
                save_image(save1,'./vae_img/image_{}.png'.format(e + 1))
                save_image(save, './vae_img/image_pre{}.png'.format(e + 1))
                # -------------------
                # x, _ = train_set[0]
                # x = x.view(x.shape[0], -1)
                # if torch.cuda.is_available():
                #     x = x.cuda()
                # x = Variable(x)
                # _, _, mu, _ = net(x)
                # print(mu)
        print('epoch %d, loss %.4f' % (e + 1, train_l_sum / batch_count))
    plt.title('train loss')
    plt.plot(Y)
    plt.show()

    view_data = train_set.train_data[:].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
    # encoded_data, _ = autoencoder(view_data)  # 提取压缩的特征值
    z, _, _, _ =  net(view_data)
    fig = plt.figure(2)
    # ax = Axes3D(fig)  # 3D 图
    # x, y, z 的数据值
    x = z.data[:, 0].numpy()
    y = z.data[:, 1].numpy()
    # Z = encoded_data.data[:, 2].numpy()
    values = train_set.train_labels[:].numpy()  # 标签值
    # for x, y, s in zip(X, Y, values):
    #     cs = cm.rainbow(int(255 * s / 9))  # 上色
    # 绘制散点图,x为x轴坐标,y为y轴坐标,values为颜色值,marker为点型
    plt.scatter(x,y,c=values,marker='.')
    # 设置图片大小
    plt.colorbar()
    #     ax.text(x, y, z, s, backgroundcolor=c)  # 标位子
    # ax.set_xlim(X.min(), X.max())
    # ax.set_ylim(Y.min(), Y.max())
    # ax.set_zlim(Z.min(), Z.max())
    plt.show()

if __name__ == "__main__":
    main()

预测代码如下,

import torch
from VAE1 import *
path_model = "./VAEmodel.pth"
net_load = torch.load(path_model)
print(net_load)


gridx = 10
gridy = 20

def img_iter():
    for i in range(gridx*gridy):

        code = Variable(torch.FloatTensor([[(gridx*gridy-i**2), (gridx*gridy-i**2)]]).to(device))
        decode = net_load.decode(code)
        # print(decode.size(), '2')
        decode_img = to_img2(decode).squeeze()
        # print(decode_img.size(), '3')
        decode_img = decode_img.data.cpu().numpy() * 255
        decode_img = torch.from_numpy(decode_img)
        # decode_img = torch.view(1,28,28)
        # print(decode_img.size(), '4')
        decode_img = decode_img.unsqueeze(dim=0)
        decode_img = decode_img.unsqueeze(dim=0)

        # print(decode_img.size(), '5')

        if i >= 1:
            all_img = torch.cat((all_img, decode_img), 0)
            # all_img = torch.stack([all_img,decode_img],dim=0)
            # torch.stack([a, b], dim=0)
            # print(all_img.size(),'10'+str(i))
        elif i < 1:
            all_img = decode_img

    return all_img

def imshow(img):
    print(img.shape)
    # img = img / 2 + 0.5  # unnormalize
    # npimg = img.numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()
code = Variable(torch.FloatTensor([[-10, -3]]).to(device))
decode = net_load.decode(code)
decode_img = to_img2(decode).squeeze()
decode_img = decode_img.data.cpu().numpy() * 255
# all_decode_img = img_iter()
# print(all_decode_img.size(),'1')
# imshow(torchvision.utils.make_grid(all_decode_img, nrow=15, padding=1))
plt.imshow(decode_img.astype('uint8'), cmap='gray')
plt.show()

猜你喜欢

转载自blog.csdn.net/Orange_sparkle/article/details/134872648