A DDPM Code Walkthrough
A Simple Example
1. Initializing the Dataset
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch
# Generate an S-curve dataset of 10^4 points with noise level 0.1
s_curve, _ = make_s_curve(10**4, noise=0.1)
# Keep only the x and z coordinates (the first and third columns) and scale down by 10 to make the data more compact
s_curve = s_curve[:, [0, 2]] / 10.0
print("shape of s:", np.shape(s_curve))  # => (10000, 2)
# Transpose so that each row holds one coordinate dimension, which lets us unpack it into scatter's x and y arguments
data = s_curve.T
fig, ax = plt.subplots()
# Draw the scatter plot on the axes: blue points with white edges
ax.scatter(*data, color='blue', edgecolor='white');
# Hide the axes
ax.axis('off')
# Convert the numpy array s_curve into a PyTorch tensor of dtype float
dataset = torch.Tensor(s_curve).float()
2. Setting the Hyperparameters
The hyperparameters are:
num_steps: $T$, the total number of diffusion steps
betas: $\beta_t$, the noise schedule
alphas: $\alpha_t = 1 - \beta_t$
alphas_prod: $\bar{\alpha}_t = \prod_{i=1}^{t}{\alpha_i}$
alphas_bar_sqrt: $\sqrt{\bar{\alpha}_t}$
one_minus_alphas_bar_log: $\log_e{(1 - \bar{\alpha}_t)}$
one_minus_alphas_bar_sqrt: $\sqrt{1 - \bar{\alpha}_t}$
# Total number of steps in the diffusion process
num_steps = 100
# An evenly spaced sequence from -6 to 6, used as the raw values for beta
betas = torch.linspace(-6, 6, num_steps)
# Squash the raw values into (0, 1) with a sigmoid, then rescale and shift
# so that betas lies in (1e-5, 0.5e-2) -- a common choice for controlling the diffusion process
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
# print(betas.detach().numpy())
# alphas is the fraction of signal preserved at each step: 1 minus betas
alphas = 1 - betas
# Cumulative product of alphas along dimension 0, giving alpha_bar
alphas_prod = torch.cumprod(alphas, 0)
# alphas_prod shifted right by one step with an initial value of 1 (i.e. alpha_bar_{t-1}); it is not used later
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
# Square root of alphas_prod, giving alpha_bar_sqrt, used in later computations
alphas_bar_sqrt = torch.sqrt(alphas_prod)
# Natural log of (1 - alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
# Square root of (1 - alphas_prod), used in later computations
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
# Make sure all alpha-related tensors have the same shape -- important for the later computations
assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == \
    alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == \
    one_minus_alphas_bar_sqrt.shape
# Print the shape of betas to confirm its length matches num_steps
print("all the same shape", betas.shape)  # => all the same shape torch.Size([100])
3. The Diffusion Function
$x_t = \sqrt{\bar{\alpha}_t}\cdot x_0 + \sqrt{1 - \bar{\alpha}_t}\cdot z$
# Sample x at an arbitrary timestep, based on x_0 and the reparameterization trick
def q_x(x_0, t):
    """Compute x[t] at any timestep t directly from x[0]"""
    noise = torch.randn_like(x_0)  # sample from a standard normal distribution
    alphas_t = alphas_bar_sqrt[t]  # alpha_bar_sqrt at timestep t
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]  # one_minus_alphas_bar_sqrt at timestep t
    return (alphas_t * x_0 + alphas_1_m_t * noise)  # add noise on top of x[0]
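A quick sanity check, not in the original post, that q_x behaves as the formula predicts: since the added noise has unit variance, the per-coordinate variance of $x_t$ should be $\bar{\alpha}_t \, \mathrm{Var}(x_0) + (1 - \bar{\alpha}_t)$:
# Sketch: compare the empirical std of q_x(dataset, t) against the closed-form prediction
with torch.no_grad():
    for t_check in [0, 50, 99]:
        x_t = q_x(dataset, torch.tensor([t_check]))
        predicted = torch.sqrt(alphas_prod[t_check] * dataset.var(dim=0) + (1 - alphas_prod[t_check]))
        print(t_check, x_t.std(dim=0), predicted)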
4. Plotting Snapshots of the Diffusion Process
# Number of snapshots to display
num_shows = 20
# A 2 x 10 grid of subplots, one per snapshot
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
# Set the text color to black
plt.rc('text', color='black')
# There are 10000 points, each with two coordinates
# Plot the noised data every 5 steps within the first 100 steps
for i in range(num_shows):
    # Position of the current snapshot in the subplot grid:
    # j is the row index, k is the column index
    j = i // 10
    k = i % 10
    # i*num_steps//num_shows samples every 5 steps, giving the timestep t
    one_t = i * num_steps // num_shows
    q_i = q_x(dataset, torch.tensor([one_t]))  # data sampled at timestep t
    # Scatter-plot the sampled data in the corresponding subplot:
    # red points with white edges
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
    # Hide the subplot's axes
    axs[j, k].set_axis_off()
    # A raw LaTeX string for a nicer mathematical title
    axs[j, k].set_title(r'$q(\mathbf{x}_{' + str(one_t) + '})$')
5. The Neural Network Model
import torch
import torch.nn as nn
class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        # Each embedding layer maps the timestep t to a num_units-dimensional vector, encoding the timestep.
        # Using three embedding layers lets the model learn richer timestep representations;
        # each one can capture a different aspect of the timestep.
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
            ]
        )

    def forward(self, x, t):
        for idx, embedding_layer in enumerate(self.step_embeddings):
            # Encode the timestep
            t_embedding = embedding_layer(t)
            # Pass the data through a linear layer
            x = self.linears[2 * idx](x)
            # Fuse the timestep embedding with the data
            x += t_embedding
            # Pass the data through the activation layer
            x = self.linears[2 * idx + 1](x)
        x = self.linears[-1](x)
        return x
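A quick shape check, not in the original post, confirming that the network maps a batch of 2-D points plus timesteps back to 2-D outputs:
# Sketch: a batch of 16 points with random timesteps
demo_net = MLPDiffusion(n_steps=100)
x_dummy = torch.randn(16, 2)
t_dummy = torch.randint(0, 100, (16,))
print(demo_net(x_dummy, t_dummy).shape)  # => torch.Size([16, 2])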
6. The Loss Function
$L_{simple}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon,\ t)\|^2\right]$
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """
    Compute the diffusion model's loss.

    Arguments:
    - model: the diffusion model; takes noised data and a timestep, and predicts the noise.
    - x_0: the original data, of shape [batch_size, 2].
    - alphas_bar_sqrt: square root of the cumulative product of alpha, used to scale the signal.
    - one_minus_alphas_bar_sqrt: square root of (1 - cumulative product of alpha), used to scale the noise.
    - n_steps: total number of steps in the diffusion process.

    Returns:
    - the loss, measuring the gap between the predicted and the true noise.
    """
    # Number of samples in the batch
    batch_size = x_0.shape[0]
    # Draw random timesteps t for half of the batch
    t = torch.randint(0, n_steps, size=(batch_size // 2,))
    # For symmetry, use the mirrored timesteps n_steps-1-t for the other half
    t = torch.cat([t, n_steps - 1 - t], dim=0)
    # Add a trailing dimension so t broadcasts against x_0
    t = t.unsqueeze(-1)
    # Coefficient of x_0 in the formula
    a = alphas_bar_sqrt[t]
    # Coefficient of epsilon in the formula
    aml = one_minus_alphas_bar_sqrt[t]
    # Random noise epsilon with the same shape as x_0, drawn from a standard normal
    e = torch.randn_like(x_0)
    # Build the model input
    x = x_0 * a + e * aml
    # Run the model to predict the noise at timestep t
    output = model(x, t.squeeze(-1))
    # Mean squared error against the true noise
    return (e - output).square().mean()
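For example, a single loss evaluation on a random mini-batch, using the globals defined above (the fresh, untrained network here is only for illustration):
# Sketch: evaluate the loss once on 64 random points from the dataset
demo_net = MLPDiffusion(num_steps)
batch = dataset[torch.randperm(dataset.shape[0])[:64]]
print(diffusion_loss_fn(demo_net, batch, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps).item())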
7. The Reverse-Diffusion Sampling Functions
$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$
def p_sample_loop(model, shape, n_steps, betas, one_minus_alphas_bar_sqrt):
    """
    Recover the data from pure noise with the diffusion model, step by step.

    Arguments:
    - model: the diffusion model, used to predict the noise at each step.
    - shape: shape of the data to recover, e.g. (batch_size, num_features).
    - n_steps: total number of steps in the diffusion process.
    - betas: the per-step beta values controlling the noise level.
    - one_minus_alphas_bar_sqrt: square root of (1 - alphas_prod).

    Returns:
    - x_seq: a list containing the recovered data at every step.
    """
    # Initial noise drawn from a standard normal distribution
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_steps)):  # iterate backwards, from x_T to x_0
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq

def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """
    Sample the reconstruction at timestep t-1 from the noised data at timestep t.

    Arguments:
    - model: the diffusion model, used to predict the noise at timestep t.
    - x: the noised data at the current timestep.
    - t: the current timestep.
    - betas: the per-step beta values controlling the noise level.
    - one_minus_alphas_bar_sqrt: square root of (1 - alphas_prod).

    Returns:
    - sample: the reconstructed sample at timestep t-1.
    """
    # Wrap the timestep in a tensor
    t = torch.tensor([t])
    # Coefficient of epsilon_theta in the formula
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    # Predict the noise with the model
    eps_theta = model(x, t)
    # Mean of the distribution at the previous timestep
    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))
    # Random sample from a standard normal distribution
    z = torch.randn_like(x)
    # Standard deviation of the distribution at the previous timestep
    sigma_t = betas[t].sqrt()
    # Data at the previous timestep
    sample = mean + sigma_t * z
    return (sample)
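One detail worth noting: in the DDPM paper's sampling algorithm the noise term is dropped at the final step, so that $x_0$ is just the predicted mean, whereas the code above always adds noise. A minimal sketch of that variant:
# Sketch of p_sample that skips the noise at the last step, as in the original DDPM sampler
def p_sample_no_final_noise(model, x, t, betas, one_minus_alphas_bar_sqrt):
    t_tensor = torch.tensor([t])
    coeff = betas[t_tensor] / one_minus_alphas_bar_sqrt[t_tensor]
    eps_theta = model(x, t_tensor)
    mean = (1 / (1 - betas[t_tensor]).sqrt()) * (x - coeff * eps_theta)
    if t == 0:
        return mean  # no noise at the final step
    return mean + betas[t_tensor].sqrt() * torch.randn_like(x)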
八、训练模型
# 设置随机种子以保证结果的可重复性
seed = 1234
class EMA():
"""构建一个参数平滑器"""
def __init__(self, mu=0.01):
# 平滑系数
self.mu = mu
# 存储平滑后的参数
self.shadow = {
}
def register(self,name,val):
# 注册参数的初始值
self.shadow[name] = val.clone()
def __call__(self,name,x):
# 根据平滑系数更新参数
assert name in self.shadow
new_average = self.mu * x + (1.0-self.mu)*self.shadow[name]
self.shadow[name] = new_average.clone()
return new_average
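Note that EMA is defined here but never invoked in the training loop below. A hypothetical sketch of how it could be wired in (the demo_model name is a stand-in, not part of the original code):
# Sketch: register every parameter once, then smooth after each optimizer step
ema = EMA(mu=0.01)
demo_model = MLPDiffusion(num_steps)  # stand-in; in practice use the model being trained
for name, param in demo_model.named_parameters():
    ema.register(name, param.data)
# ...inside the training loop, after optimizer.step():
for name, param in demo_model.named_parameters():
    param.data = ema(name, param.data)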
print('Training model...')
# Batch size
batch_size = 128
# Data loader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Number of training epochs
num_epoch = 4000
# Set the matplotlib text color
plt.rc('text', color='blue')
# Instantiate the MLPDiffusion model; num_steps was defined above as 100
model = MLPDiffusion(num_steps)
# Instantiate the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(num_epoch):
    # Report training progress
    process = t / num_epoch * 100
    if t % 50 == 0:
        print(f'{int(process)}%')
    # Iterate over the batches within each epoch
    for idx, batch_x in enumerate(dataloader):
        # Compute the diffusion loss
        loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
        # Clear the previous gradients
        optimizer.zero_grad()
        # Backpropagate to compute the gradients of the current loss
        loss.backward()
        # Clip the gradients to prevent them from exploding
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        # Update the model parameters
        optimizer.step()
    # Every 100 epochs, plot the generated samples to check progress
    if (t % 100 == 0):
        # Print the current loss
        print(loss)
        # Generate a sequence of samples with the model
        x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt)
        # Plot the sample sequence
        fig, axs = plt.subplots(1, 10, figsize=(28, 3))
        for i in range(1, 11):
            cur_x = x_seq[i * 10].detach()
            axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white');
            axs[i - 1].set_axis_off();
            axs[i - 1].set_title(r'$q(\mathbf{x}_{' + str(i * 10) + '})$')
9. Rendering the Diffusion Process
# Imports for building the GIF frames
import io
from PIL import Image
# An empty list to collect the generated frames
imgs = []
# Loop 100 times, producing one frame per diffusion step
for i in range(100):
    # Clear the current figure to prepare for the next frame
    plt.clf()
    # q_x (defined above) noises the dataset up to timestep i;
    # torch.tensor([i]) wraps the integer i in a PyTorch tensor
    q_i = q_x(dataset, torch.tensor([i]))
    # Draw the scatter plot on the freshly cleared figure:
    # q_i[:, 0] and q_i[:, 1] are the x and y coordinates;
    # color sets the point color, edgecolor the edge color, s the point size
    plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5)
    # Hide the axes
    plt.axis('off')
    # A byte buffer to hold the image data
    img_buf = io.BytesIO()
    # Save the current figure into the buffer as a PNG
    plt.savefig(img_buf, format='png')
    # Open the image from the buffer with PIL
    img = Image.open(img_buf)
    # Append it to the list of frames
    imgs.append(img)
10. Rendering the Reverse-Diffusion Process
# An empty list to collect the frames of the reverse process
reverse = []
# Loop 100 times, producing one frame per reverse step
for i in range(100):
    # Clear the current figure to prepare for the next frame
    plt.clf()
    # x_seq holds all intermediate samples from the last call to p_sample_loop;
    # take its i-th element and detach it from the computation graph
    cur_x = x_seq[i].detach()
    # Draw the scatter plot:
    # cur_x[:, 0] and cur_x[:, 1] are the x and y coordinates;
    # color sets the point color, edgecolor the edge color, s the point size
    plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white', s=5)
    # Hide the axes for a cleaner image
    plt.axis('off')
    # A byte buffer to hold the image data
    img_buf = io.BytesIO()
    # Save the current figure into the buffer as a PNG
    plt.savefig(img_buf, format='png')
    # Open the image from the buffer with PIL
    img = Image.open(img_buf)
    # Append it to the reverse list
    reverse.append(img)
11. Concatenating the Diffusion and Reverse-Diffusion Frames
imgs = imgs + reverse
12. Saving Both Processes as a GIF Animation
# imgs[0] is the first frame of the list built above;
# save writes the image to a file, here as a GIF
imgs[0].save(
    "diffusion.gif",         # output file name
    format='GIF',            # output file format
    append_images=imgs[1:],  # append the remaining frames after the first (imgs would duplicate frame 0)
    save_all=True,           # tell PIL to save every frame as part of the GIF animation
    duration=100,            # duration of each frame, in milliseconds
    loop=0                   # number of loops; 0 means loop forever
)
Appendix: The Neural Network Model from the Paper
The overall model structure
The SiLU activation and the choice of normalization
class SiLU(nn.Module):
    # SiLU activation: x * sigmoid(x)
    @staticmethod
    def forward(x):
        return x * torch.sigmoid(x)

def get_norm(norm, num_channels, num_groups):
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    elif norm == "bn":
        return nn.BatchNorm2d(num_channels)
    elif norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    elif norm is None:
        return nn.Identity()
    else:
        raise ValueError("unknown normalization type")
The positional-encoding layer
$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$
$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$
class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # Outer product of x * self.scale with emb
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
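A shape check for a batch of timesteps (this fragment additionally needs the import math and import torch from the complete code below):
# Sketch: embed a batch of 8 timesteps into 128 dimensions
pe = PositionalEmbedding(dim=128)
timesteps = torch.arange(8, dtype=torch.float32)
print(pe(timesteps).shape)  # => torch.Size([8, 128])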
The attention module
class AttentionBlock(nn.Module):
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # A 1x1 convolution that triples the channel count, producing the query (Q), key (K) and value (V)
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # Triple the channels, then split along the channel dimension into the Q, K and V matrices
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)
        # torch.bmm treats its two inputs as batches of matrices and multiplies them pairwise
        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)
        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)
        return self.to_out(out) + x
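The block preserves the input's shape, which the sketch below confirms (the feature-map size is illustrative; num_groups=32 must divide the channel count):
# Sketch: self-attention over an 8x8 feature map with 64 channels
attn = AttentionBlock(in_channels=64)
feat = torch.randn(2, 64, 8, 8)
print(attn(feat).shape)  # => torch.Size([2, 64, 8, 8])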
The residual module
class ResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dropout,
        time_emb_dim=None,
        num_classes=None,
        activation=SiLU(),   # SiLU activation by default
        norm="gn",           # group normalization by default
        num_groups=32,
        use_attention=False,
    ):
        super().__init__()
        # Activation function
        self.activation = activation
        # Normalization layers
        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )
        self.time_bias = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None
        self.residual_connection = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        # nn.Identity() is a special module implementing the identity function: its output equals its input
        self.attention = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)

    def forward(self, x, time_emb=None, y=None):
        # Normalization + activation + first convolution
        out = self.activation(self.norm_1(x))
        out = self.conv_1(out)
        # Project the time embedding time_emb through a linear layer and add it to the channels
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            # Activation + linear layer
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]
        # Embed the class label y and add it to the channels
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")
            out += self.class_bias(y)[:, :, None, None]
        # Normalization + activation
        out = self.activation(self.norm_2(out))
        # Second convolution + residual connection
        out = self.conv_2(out) + self.residual_connection(x)
        # Finally, attention (or identity)
        out = self.attention(out)
        return out
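A sketch of a forward pass with time conditioning (all shapes are illustrative; num_groups=32 must divide both channel counts):
# Sketch: a residual block mapping 64 -> 128 channels, conditioned on a 512-dim time embedding
block = ResidualBlock(64, 128, dropout=0.1, time_emb_dim=512)
x_in = torch.randn(2, 64, 16, 16)
t_emb = torch.randn(2, 512)
print(block(x_in, time_emb=t_emb).shape)  # => torch.Size([2, 128, 16, 16])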
The downsampling layer
class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # Downsample with a 3x3 convolution of stride 2
        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

    def forward(self, x, time_emb, y):
        # Check that the height is an even number of pixels
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        # Check that the width is an even number of pixels
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")
        return self.downsample(x)
The upsampling layer
class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            # Nearest-neighbor interpolation for the upsampling
            nn.Upsample(scale_factor=2, mode="nearest"),
            # Followed by a convolution
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, time_emb, y):
        return self.upsample(x)
The U-Net
class UNet(nn.Module):
    def __init__(
        self,
        img_channels,
        base_channels=128,
        channel_mults=(1, 2, 4, 8),
        num_res_blocks=3,
        time_emb_dim=128 * 4,
        time_emb_scale=1.0,
        num_classes=None,
        activation=SiLU(),
        dropout=0.1,
        attention_resolutions=(1,),
        norm="gn",
        num_groups=32,
        initial_pad=0,
    ):
        super().__init__()
        # Activation function, usually SiLU
        self.activation = activation
        # Amount of padding to apply to the input, if any
        self.initial_pad = initial_pad
        # Number of classes to condition on, if any
        self.num_classes = num_classes
        # MLP that encodes the timestep t
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None
        # First convolution applied to the input image
        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)
        # self.downs holds every module of the downsampling path
        # self.ups holds every module of the upsampling path
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        # channels records the channel count after each module
        # now_channels is a running variable holding the current channel count
        channels = [base_channels]
        now_channels = base_channels
        for i, mult in enumerate(channel_mults):
            # Channel count at this resolution of the downsampling path
            out_channels = base_channels * mult
            # Add the residual (+ optional attention) blocks at this resolution
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                        norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)
        # Feature extraction in the middle: two residual blocks back to back
        self.mid = nn.ModuleList(
            [
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=True,
                ),
                ResidualBlock(
                    now_channels, now_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=False,
                ),
            ]
        )
        # Upsampling path, fusing in the skip-connection features
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks + 1):
                self.ups.append(ResidualBlock(
                    channels.pop() + now_channels, out_channels, dropout,
                    time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
                    norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,
                ))
                now_channels = out_channels
            if i != 0:
                self.ups.append(Upsample(now_channels))
        assert len(channels) == 0
        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        # Pad the input if requested
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)
        # Encode the timestep through the time MLP
        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None
        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")
        # First convolution applied to the input image
        x = self.init_conv(x)
        # skips stores the intermediate features of the downsampling path
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)
        # Feature extraction in the middle
        for layer in self.mid:
            x = layer(x, time_emb, y)
        # Upsample and fuse in the skip features
        for layer in self.ups:
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)
        # Final normalization + activation + output convolution
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)
        if self.initial_pad != 0:
            return x[:, :, ip:-ip, ip:-ip]
        else:
            return x
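A smoke test of the full U-Net (a sketch with a deliberately small, assumed configuration; it needs the imports from the complete code below). The output shape must match the input shape, since the network predicts per-pixel noise:
# Sketch: a tiny U-Net on 3-channel 32x32 images
unet = UNet(img_channels=3, base_channels=64, channel_mults=(1, 2), num_res_blocks=1)
imgs_in = torch.randn(2, 3, 32, 32)
ts = torch.randint(0, 1000, (2,), dtype=torch.float32)
print(unet(imgs_in, time=ts).shape)  # => torch.Size([2, 3, 32, 32])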
The complete model code
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# The complete file is simply the modules defined above -- SiLU, get_norm, PositionalEmbedding,
# Downsample, Upsample, AttentionBlock, ResidualBlock and UNet -- collected after these imports.