Before reading this article, please first get familiar with the theory post I wrote earlier; the names in the code stay as close to the original paper as possible, for example ze, z and zq.
VQ-VAE原理 (VQ-VAE Principles)
The theory will not be repeated here. As before, the MNIST dataset is used.
1. The sampling model: PixelCNN
We know that a VQ-VAE cannot draw random samples by itself the way a VAE can; strictly speaking it is not really a VAE at all, so an auxiliary model is needed to produce random samples of the latent. The paper uses a PixelCNN. If you want to learn about PixelCNN, see my earlier post pixel cnn原理 (PixelCNN Principles). This article uses the gated variant of PixelCNN. Before the PixelCNN code itself, the sketch below shows how the two pieces cooperate end to end.
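This is my own orientation sketch, not part of the reference implementation: it uses the VQVAE and PixelCnnWithEmbedding classes defined later in this post (saved as vqvae.py and pixelcnn.py, which is how the training script imports them), with random weights, so only the shapes and the division of labour matter. I pass color_level=32 so the PixelCNN's output classes match the code book size; the main() function later keeps the default of 256, which also works because the code indices never exceed 31.

import torch
from vqvae import VQVAE
from pixelcnn import PixelCnnWithEmbedding

vqvae = VQVAE(input_dim=1, dim=32, n_embedding=32)
prior = PixelCnnWithEmbedding(n_blocks=15, p=128, linear_dim=32, color_level=32)

x = torch.rand(4, 1, 28, 28)   # a batch of MNIST-sized images
z = vqvae.encode_z(x)          # discrete code map, shape (4, 7, 7), values in [0, 32)
logits = prior(z)              # the PixelCNN models p(z); logits shape (4, 32, 7, 7)
x_hat = vqvae.decode_z(z)      # map codes back to image space, shape (4, 1, 28, 28)
print(z.shape, logits.shape, x_hat.shape)

With that overall picture in mind, here is the gated PixelCNN code.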
import torch.nn as nn
import torch
from torchinfo import summary
import torch.nn.functional as F
class VerticalMaskConv2d(nn.Module):
def __init__(self, *args, **kwags):
super().__init__()
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[0:H // 2 + 1] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)
def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res
class HorizontalMaskConv2d(nn.Module):
def __init__(self, conv_type, *args, **kwags):
super().__init__()
assert conv_type in ('A', 'B')
self.conv = nn.Conv2d(*args, **kwags)
H, W = self.conv.weight.shape[-2:]
mask = torch.zeros((H, W), dtype=torch.float32)
mask[H // 2, 0:W // 2] = 1
if conv_type == 'B':
mask[H // 2, W // 2] = 1
mask = mask.reshape((1, 1, H, W))
self.register_buffer('mask', mask, False)
def forward(self, x):
self.conv.weight.data *= self.mask
conv_res = self.conv(x)
return conv_res
class GatedBlock(nn.Module):
def __init__(self, conv_type, in_channels, p, bn=True):
super().__init__()
self.conv_type = conv_type
self.p = p
self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, 1)
self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
1)
self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
self.h_output_conv = nn.Conv2d(p, p, 1)
self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()
def forward(self, v_input, h_input):
v = self.v_conv(v_input)
v = self.bn1(v)
v_to_h = v[:, :, 0:-1]
v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
v_to_h = self.v_to_h_conv(v_to_h)
v_to_h = self.bn2(v_to_h)
v1, v2 = v[:, :self.p], v[:, self.p:]
v1 = torch.tanh(v1)
v2 = torch.sigmoid(v2)
v = v1 * v2
h = self.h_conv(h_input)
h = self.bn3(h)
h = h + v_to_h
h1, h2 = h[:, :self.p], h[:, self.p:]
h1 = torch.tanh(h1)
h2 = torch.sigmoid(h2)
h = h1 * h2
h = self.h_output_conv(h)
h = self.bn4(h)
if self.conv_type == 'B':
h = h + h_input
return v, h
class GatedPixelCNN(nn.Module):
def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__()
self.block1 = GatedBlock('A', 1, p, bn)
self.blocks = nn.ModuleList()
for _ in range(n_blocks):
self.blocks.append(GatedBlock('B', p, p, bn))
self.relu = nn.ReLU()
self.linear1 = nn.Conv2d(p, linear_dim, 1)
self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
self.out = nn.Conv2d(linear_dim, color_level, 1)
def forward(self, x):
v, h = self.block1(x, x)
for block in self.blocks:
v, h = block(v, h)
x = self.relu(h)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
x = self.out(x)
return x
class PixelCnnWithEmbedding(GatedPixelCNN):
def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
super().__init__(n_blocks, p, linear_dim, bn, color_level)
self.embedding = nn.Embedding(color_level, p)
self.block1 = GatedBlock('A', p, p, bn)
def forward(self, x):
"""
        x: (N, H, W), the discrete code map z as input
return: (N, 256, H, W)
"""
x = self.embedding(x)
x = x.permute(0, 3, 1, 2).contiguous()
return super().forward(x)
if __name__ == '__main__':
# net1 = GatedPixelCNN(15, 128, 32)
# net1.block1 = GatedBlock('A', 128, 128, True)
# summary(net1, input_size=(1, 128, 28, 28))
net2 = PixelCnnWithEmbedding(15, 128, 32)
input_data = torch.randint(0,256, (1,28,28))
summary(net2, input_data=input_data)
Running this prints the rough model structure; note that the input and output H, W are identical. The important point is that this PixelCNN is used to generate the discrete latent z, not the original image, because what we ultimately want is a randomly sampled discrete latent.
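To convince yourself that the masked convolutions really are autoregressive, here is a small sanity check of my own (not from the reference code), assuming the model above is saved as pixelcnn.py: perturbing the code at position (4, 4) must leave the logits for (4, 4) and every earlier raster-order position unchanged, because each output only conditions on previous positions.

import torch
from pixelcnn import PixelCnnWithEmbedding

net = PixelCnnWithEmbedding(n_blocks=2, p=16, linear_dim=8, color_level=32)
net.eval()  # freeze BatchNorm so the two forward passes are deterministic

z = torch.randint(0, 32, (1, 8, 8))
z_perturbed = z.clone()
z_perturbed[0, 4, 4] = (z_perturbed[0, 4, 4] + 1) % 32  # change a single code

with torch.no_grad():
    logits_a = net(z)            # (1, 32, 8, 8)
    logits_b = net(z_perturbed)

# the logit for (4, 4) conditions only on earlier positions, so it must not change
print(torch.allclose(logits_a[0, :, 4, 4], logits_b[0, :, 4, 4]))  # expected: True
# a later position such as (5, 5) is allowed to change
print(torch.allclose(logits_a[0, :, 5, 5], logits_b[0, :, 5, 5]))  # typically: False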
2. The VQ-VAE model
With the auxiliary sampling model in place, the next piece is the VQ-VAE itself. Straight to the code.
import torch.nn as nn
import torch
# 1. Residual block
class ResidualBlock(nn.Module):
def __init__(self, dim):
super(ResidualBlock, self).__init__()
self.res_block = nn.Sequential(
nn.ReLU(),
nn.Conv2d(dim, dim, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(dim, dim, 1)
)
def forward(self, x):
x = x + self.res_block(x)
return x
class VQVAE(nn.Module):
def __init__(self, input_dim, dim, n_embedding):
"""
        input_dim: number of input channels, e.g. 3 for RGB images
        dim: number of channels of the encoded ze
        n_embedding: number of vectors in the code book
"""
super(VQVAE, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(input_dim, dim, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(dim, dim, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim),
ResidualBlock(dim)
)
self.decoder = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1),
ResidualBlock(dim),
ResidualBlock(dim),
nn.ConvTranspose2d(dim, dim, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(dim, input_dim, 4, 2, 1)
)
self.n_downsample = 2
# code book
self.vq_embedding = nn.Embedding(n_embedding, dim)
self.vq_embedding.weight.data.uniform_(-1.0 / n_embedding, 1.0/n_embedding)
def forward(self, x):
"""
x, shape(N,C0,H0,W0)
"""
# encoder (N,C,H,W)
ze = self.encoder(x)
# code book embedding [K, C]
embedding = self.vq_embedding.weight.data
N, C, H, W = ze.shape
K, _ = embedding.shape
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
        # nearest-neighbour search: this step produces zq. We first map ze -> z
        # (the index of the closest code book vector) and then z -> zq; z is only
        # an intermediate variable, and (zq - ze).detach() below detaches this
        # non-differentiable lookup from the computation graph.
        distance = torch.sum((embedding_broadcast - ze_broadcast) ** 2, 2)  # (N,K,H,W)
        nearest_neighbor = torch.argmin(distance, 1)  # (N,H,W)
        # zq (N, C, H, W) : (N, H, W, C) -> (N, C, H, W)
        zq = self.vq_embedding(nearest_neighbor).permute(0, 3, 1, 2)
        # straight-through: the decoder input equals zq in value, ze + sg(zq - ze)
        decoder_input = ze + (zq - ze).detach()
# decoder
x_hat = self.decoder(decoder_input)
return x_hat, ze, zq
    # encode_z: this step produces the discrete variable z (analogous to pixel values), used both as input and as target to train the PixelCNN; the PixelCNN's job is to reconstruct and generate z
@torch.no_grad()
def encode_z(self, x):
ze = self.encoder(x)
embedding = self.vq_embedding.weight.data
# ze: [N, C, H, W]
# embedding [K, C]
N, C, H, W = ze.shape
K, _ = embedding.shape
embedding_broadcast = embedding.reshape(1, K, C, 1, 1)
ze_broadcast = ze.reshape(N, 1, C, H, W)
distance = torch.sum((embedding_broadcast - ze_broadcast)**2, 2)
nearest_neighbor = torch.argmin(distance, 1)
return nearest_neighbor
    # decode_z: this step turns the latent produced by the PixelCNN into the final image, since what the PixelCNN generates is the discrete latent z
@torch.no_grad()
def decode_z(self, latent_z):
"""
        latent_z: shape (N, H, W)
"""
# zq (N, C, H, W)
zq = self.vq_embedding(latent_z).permute(0,3,1,2)
x_hat = self.decoder(zq)
return x_hat
# shape: [C,H,W]
def get_latent_HW(self, input_shape):
C, H, W = input_shape
return H // 2 ** self.n_downsample, W // 2 ** self.n_downsample
if __name__ == '__main__':
from torchinfo import summary
vqvae = VQVAE(1, 32, 32)
summary(vqvae, input_size=[1,1,28,28])
There is not much to add here; the important notes are all in the code comments. Pay particular attention to the conversions among ze, z and zq, and to the use of detach() to cut the computation graph. For MNIST's 28x28 inputs, the two stride-2 convolutions in the encoder give a 7x7 latent grid.
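To make the detach() trick concrete, here is a tiny toy illustration of my own with standalone tensors (not the model above): decoder_input equals zq in value, yet the gradient of anything computed from it flows back into ze as if the decoder had consumed ze directly.

import torch

ze = torch.randn(4, requires_grad=True)  # stand-in for the encoder output
zq = torch.randn(4)                       # stand-in for the nearest code book vectors

decoder_input = ze + (zq - ze).detach()   # value equals zq, but the grad path runs through ze
loss = decoder_input.sum()                # stand-in for the reconstruction loss
loss.backward()

print(torch.allclose(decoder_input, zq))  # True: the forward pass sees zq
print(ze.grad)                            # tensor of ones: the gradient reaches ze unchanged
print(zq.grad)                            # None: nothing flows into the code book through this path

In the real model the code book is still trained, but through the separate loss_zq term in the next section's training loop rather than through the decoder.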
3. Training and inference code
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import cv2
from vqvae import VQVAE
from pixelcnn import GatedPixelCNN, PixelCnnWithEmbedding
import einops
import numpy as np
# MNIST is used as the dataset again
# take a look at what MNIST samples look like
def mnist_show():
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True)
print('length of MNIST', len(mnist))
img, label = mnist[0]
print(img)
print(label)
img.show()
tensor = transforms.ToTensor()(img)
print(tensor.shape) # torch.Size([1, 28, 28]) CHW
print(tensor.max()) # max 1,
    print(tensor.min())  # min 0, ToTensor already normalizes to [0, 1]
# mnist_show()
def train_vqvae(
model:VQVAE,
device,
dataloader,
ckpt_vqvae='vqvae_ckpt.pth',
n_epochs=100,
alpha=1,
beta=0.25,
):
model.to(device) # model = model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
mse_loss = torch.nn.MSELoss()
print("start vqvae train...")
for epo in range(n_epochs):
for img, label in dataloader:
x = img.to(device) # N1HW
x_hat, ze, zq = model(x)
# ||x - decoder(ze+sg(zq-ze))||
loss_rec = mse_loss(x, x_hat)
# ||zq - sg(ze)||
loss_zq = mse_loss(zq, ze.detach())
# ||sg(zq) - ze||
loss_ze = mse_loss(zq.detach(), ze)
loss = loss_rec + alpha * loss_zq + beta * loss_ze
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"epoch:{
epo}, loss:{
loss.item():.6f}")
        if epo % 10 == 0:
            torch.save(model.state_dict(), ckpt_vqvae)
    # also save once after the last epoch, so the final weights are not lost between checkpoints
    torch.save(model.state_dict(), ckpt_vqvae)
    print("vqvae train finish!!")
def train_gen(
vqvae:VQVAE,
model,
device,
dataloader,
ckpt_gen="gen_ckpt.pth",
n_epochs=50,
):
vqvae = vqvae.to(device)
model = model.to(device)
vqvae.eval()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
print("start pixel cnn train...")
for epo in range(n_epochs):
for x, _ in dataloader:
with torch.no_grad():
x = x.to(device)
                # get the discrete latent z
z = vqvae.encode_z(x)
            # use the PixelCNN to model this discrete latent z; note it is z, not x, being reconstructed (z -> z)
predict_z = model(z)
loss = loss_fn(predict_z, z)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"epoch:{
epo}, loss:{
loss.item():.6f}")
        if epo % 10 == 0:
            torch.save(model.state_dict(), ckpt_gen)
    # also save once after the last epoch
    torch.save(model.state_dict(), ckpt_gen)
    print("pixel cnn train finish!!")
# check the VQ-VAE reconstruction quality
def reconstruct(model, x, device):
model.to(device)
model.eval()
with torch.no_grad():
x_hat, _, _ = model(x)
n = x.shape[0]
n1 = int(n**0.5)
x_cat = torch.concat((x, x_hat), 3)
x_cat = einops.rearrange(x_cat, '(n1 n2) c h w -> (n1 h) (n2 w) c', n1=n1)
x_cat = (x_cat.clip(0, 1) * 255).cpu().numpy().astype(np.uint8)
cv2.imwrite(f'reconstruct_show.jpg', x_cat)
# check the final generated samples
def sample_imgs(
vqvae:VQVAE,
gen_model,
img_shape,
device,
n_sample=81
):
vqvae = vqvae.to(device)
gen_model = gen_model.to(device)
vqvae.eval()
gen_model.eval()
    # get the latent space H, W
C,H,W = img_shape
H, W = vqvae.get_latent_HW((C,H,W))
input_shape = (n_sample, H, W)
latent_z = torch.zeros(input_shape).to(device).to(torch.long)
# pixel cnn sample
with torch.no_grad():
for i in range(H):
for j in range(W):
output = gen_model(latent_z)
prob_dist = torch.softmax(output[:, :, i, j], -1)
pixel = torch.multinomial(prob_dist, 1)
latent_z[:, i, j] = pixel[:, 0]
    # vqvae decode: z -> x_hat
imgs = vqvae.decode_z(latent_z)
imgs = imgs * 255
imgs = imgs.clip(0, 255)
imgs = einops.rearrange(imgs,
'(n1 n2) c h w -> (n1 h) (n2 w) c',
n1=int(n_sample**0.5))
imgs = imgs.detach().cpu().numpy().astype(np.uint8)
cv2.imwrite('sample_show.jpg', imgs)
def main():
""" 代码中的公式符号尽可能和原论文一致,避免混淆,尤其是ze,z,zq这几个概念 """
device = torch.device("cuda:0")
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist, batch_size=512, shuffle=True)
    # 0. build the models
vqvae = VQVAE(1, 32, 32)
gen_model = PixelCnnWithEmbedding(15, 128, 32)
# 1. train vqvae , reconstruct
train_vqvae(vqvae, device, dataloader)
# 2. train gen model, sample
vqvae.load_state_dict(torch.load('vqvae_ckpt.pth'))
train_gen(vqvae, gen_model, device, dataloader)
gen_model.load_state_dict(torch.load('gen_ckpt.pth'))
def test():
    # after training, check the results
device = torch.device("cuda:0")
mnist = torchvision.datasets.MNIST(root='./data/mnist', download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)
batch_imgs, _ = next(iter(dataloader))
# vqvae
vqvae = VQVAE(1, 32, 32)
vqvae.load_state_dict(torch.load('vqvae_ckpt.pth'))
vqvae.eval()
vqvae = vqvae.to(device)
batch_imgs = batch_imgs.to(device)
reconstruct(vqvae, batch_imgs, device)
gen_model = PixelCnnWithEmbedding(15, 128, 32)
gen_model.load_state_dict(torch.load('gen_ckpt.pth'))
gen_model.eval()
gen_model = gen_model.to(device)
sample_imgs(vqvae, gen_model, (1, 28, 28), device)
if __name__ == '__main__':
main()
    # after training finishes, run test() to check the results
# test()
The main thing to study here is how the three loss terms are computed; the rest is routine.
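For reference, the total objective assembled in train_vqvae is the VQ-VAE loss from the paper, written in the same sg (stop-gradient, i.e. detach) notation already used in the code comments:

loss = ||x - x_hat||^2 + alpha * ||zq - sg[ze]||^2 + beta * ||ze - sg[zq]||^2

The first term is loss_rec, the second (alpha = 1) is the code book loss loss_zq, which pulls the code book vectors toward the encoder outputs, and the third (beta = 0.25) is the commitment loss loss_ze, which keeps the encoder outputs close to their chosen code book vectors.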
4. Results
First the VQ-VAE reconstructions: honestly, an ordinary autoencoder reconstructs about as well, so reconstruction alone is not very informative.
Now the generated samples: these look better than my earlier VAE, PixelCNN and GAN results. This is only the result of 50 epochs of training; training longer should look better still.
References
https://github.com/SingleZombie/DL-Demos/blob/master/dldemos/VQVAE/main.py
Many variable and method names from that reference were changed here to stay consistent with the paper and be easier to follow. Once again: make sure you understand how ze, z and zq relate to each other and at which stage each one is used.