U-Net架构

​​​​​​基本了解

UNet是一种经典的卷积神经网络架构,解决了传统方法在数据量不足时面临的挑战。最初由医学图像分割任务提出,后被广泛应用于扩散模型(如DDPM、DDIM、Stable Diffusion)中作为噪声预测的核心网络。

核心结构包括一个收缩路径(downsampling path)和一个对称的扩展路径(upsampling path)。收缩路径通过多次下采样操作捕获上下文信息,而扩展路径则通过上采样操作结合底层特征和高层特征,实现精确的像素级分割。这种U形结构设计使其能够高效地利用有限的标注样本,并在现代GPU上快速执行。

以下是UNet的详细构成及其在扩散模型中的改进设计:


一、 UNet基础架构

1. 整体结构

UNet呈对称的U型结构,包含:

  • 编码器(下采样路径):逐步提取高层语义特征,降低分辨率。

  • 解码器(上采样路径):逐步恢复空间分辨率,结合编码器特征进行精确定位。

  • 跳跃连接(Skip Connections)连接编码器与解码器对应层,保留细节信息。

2. 核心组件
层级 操作 作用
编码器层 卷积 → 激活(如ReLU) → 池化/步长卷积 提取特征,压缩空间维度
解码器层 反卷积/插值上采样 → 跳跃连接 → 卷积 恢复分辨率,融合低级与高级特征
瓶颈层 深层卷积操作 捕获全局上下文信息

二、 扩散模型中UNet的改进

在扩散模型(如DDPM、DDIM)中,UNet经过以下关键改进:

1. 时间步嵌入(Timestep Embedding)
  • 作用:将扩散过程的时间步信息注入网络,指导噪声预测。

  • 实现

    • 时间步编码为向量(通过正弦位置编码或MLP)。

    • 通过相加或拼接融入各层特征图。

2. 残差块(Residual Blocks)
  • 结构:每个块包含多个卷积层 + 归一化层(如GroupNorm) + 激活函数

  • 改进

    • 引入残差连接,缓解梯度消失。

    • 集成时间步嵌入(通过相加或自适应归一化)。

3. 注意力机制(Attention Layers)
  • 位置:通常在瓶颈层或解码器中插入。

  • 类型

    • 自注意力(Self-Attention):捕捉长程依赖。

    • 交叉注意力(Cross-Attention):用于多模态模型(如Stable Diffusion中结合文本提示)。

4. 分组归一化(Group Normalization)
  • 替代方案:相比BatchNorm,更适合小批量训练,提升稳定性。

5. 多尺度特征融合
  • 跳跃连接增强:通过通道注意力(如Squeeze-and-Excitation)动态加权特征。


三、 典型扩散模型UNet结构示例

以Stable Diffusion为例,其UNet结构参数如下:

- 输入:带噪声的潜空间特征图(64×64×4)
- 编码器:4个下采样层,每层包含2个残差块
- 瓶颈层:2个残差块 + 自注意力层
- 解码器:4个上采样层,每层包含2个残差块
- 跳跃连接:逐层传递编码器特征至解码器
- 总参数量:约860M(Stable Diffusion 1.4版本)

四、 关键设计思想

  1. 分辨率保持:通过跳跃连接保留低级细节,避免上采样时的模糊问题。

  2. 动态条件注入:时间步嵌入和文本条件(如CLIP embedding)通过自适应归一化(AdaGN)融入网络:

    • # 示例:自适应归一化(AdaGN)
      def adaptive_group_norm(x, timestep_emb, scale_shift=True):
          scale, shift = timestep_emb.chunk(2, dim=1)
          x = GroupNorm(x)
          if scale_shift:
              x = x * (1 + scale) + shift
          return x
  3. 轻量化设计

    • 使用深度可分离卷积(Depthwise Separable Conv)减少计算量。

    • 通道数动态调整(如Stable Diffusion中通道数从128到1024递增)。


五、 与传统UNet的区别

特性 传统UNet(医学分割) 扩散模型UNet
输入/输出 原始图像 → 分割掩码 噪声图像 + 时间步 → 噪声残差
归一化方式 BatchNorm GroupNorm
条件注入 时间步嵌入 + 文本/图像条件
注意力机制 自注意力/交叉注意力
参数量级 较小(几M~几十M) 较大(几百M~上B)

六、 代码框架示例(PyTorch风格)

class DiffusionUNet(nn.Module):
    def __init__(self):
        super().__init__()
        # 编码器
        self.encoder = nn.ModuleList([
            DownBlock(3, 64),          # 下采样块
            DownBlock(64, 128),
            DownBlock(128, 256)
        ])
        
        # 瓶颈层(含注意力)
        self.bottleneck = nn.Sequential(
            ResBlock(256, 512),
            SelfAttention(512),
            ResBlock(512, 512)
        )
        
        # 解码器
        self.decoder = nn.ModuleList([
            UpBlock(512, 256),          # 上采样块(含跳跃连接)
            UpBlock(256, 128),
            UpBlock(128, 64)
        ])
        
        # 时间步嵌入
        self.time_embed = nn.Sequential(
            nn.Linear(128, 256),
            nn.SiLU(),
            nn.Linear(256, 256)
        )

    def forward(self, x, t):
        t_emb = self.time_embed(t)      # 时间步嵌入
        skips = []
        for down in self.encoder:
            x = down(x, t_emb)          # 注入时间条件
            skips.append(x)
        x = self.bottleneck(x)
        for up in self.decoder:
            x = up(x, skips.pop(), t_emb)
        return x

DDPM中的应用

在扩散模型(如DDPM)中,改进的UNet结构通过以下方式整合时间嵌入(time embedding),实现对噪声的精准预测:

1. UNet的基础结构

编码器-解码器架构:UNet由对称的下采样(编码器)和上采样(解码器)路径组成,通过跳跃连接保留多尺度特征。

残差块(ResBlock):每个下采样和上采样阶段包含多个残差块,用于特征提取。

2. Time Embedding的作用

时间步编码:扩散过程中的时间步 t 被编码为高维向量(如通过正弦函数或MLP),表示当前加噪阶段。

动态调节网络:Time embedding作为条件信号,影响每一层的计算,使网络适应不同时间步的噪声分布。

3. Time Embedding的注入方式

特征图加法:Time embedding通过线性层投影后,直接添加到残差块的特征图中

自适应归一化(AdaGN):

4. 输入与输出的设计

输入:加噪图像在时间步t时由原始图像逐步添加高斯噪声得到。

输出:预测的噪声ϵθ (xt,t),目标是最小化与真实噪声的误差

5. 关键理解点

条件化每一层:每个残差块均接收相同的time embedding,确保网络在不同时间步采用不同的特征变换策略。

时间感知的噪声预测:通过时间嵌入,模型能区分早期(大尺度噪声)和晚期(细节噪声)的去噪需求,提升生成质量。

6. 扩散模型中的应用

前向过程:逐步为图像添加噪声

反向过程:UNet预测噪声,通过迭代去噪重建

猜你喜欢

转载自blog.csdn.net/m0_63855028/article/details/146296160