苹果重磅开源俄罗斯套娃扩散模型!MDM:多任务高分辨率生成又快又好!

文章链接:https://arxiv.org/pdf/2310.15111
项目链接:https://github.com/apple/ml-mdm

亮点直击

  • 提出了Matryoshka Diffusion Models (MDM),通过联合处理多个分辨率的输入,避免了级联或潜在扩散方法的复杂性,并引入了Nested UNet架构,嵌套不同尺度的特征和参数,提升了高分辨率生成的效果。

  • 采用了多分辨率损失,显著加速了高分辨率去噪的收敛速度,同时使用渐进式训练策略,从低分辨率开始逐步引入高分辨率内容,实现了训练成本和生成质量的良好平衡。

  • 在多个生成任务中表现优异,包括类条件图像生成、文本到图像生成和视频生成,尤其是在1024×1024像素分辨率下,使用较小的数据集(CC12M)依然取得了高质量生成结果。

  • MDM具备强大的泛化能力,不仅在图像生成中表现突出,还自然扩展到视频生成,展现了广泛的应用潜力。

总结速览

解决的问题:

  • 生成高质量图像和视频的扩散模型面临高维度学习的计算和优化挑战,尤其是在处理高分辨率图像时。

提出的方案:

  • 引入 Matryoshka Diffusion Model (MDM) ,一种新颖的高分辨率图像和视频合成框架,采用联合降噪的扩散过程,在多个分辨率下处理输入。

  • 使用NestedUNet架构,在大尺度输入的特征和参数中嵌套小尺度的特征和参数。

  • 设计了一种从低分辨率到高分辨率的渐进式训练方式,优化高分辨率生成的效果。

应用的技术:

  • MDM通过多分辨率联合降噪的扩散过程,使用NestedUNet架构,在像素空间直接进行高分辨率生成。

  • 使用渐进式训练策略,逐步从低分辨率到高分辨率,解决了高分辨率生成的优化问题。

达到的效果:

  • MDM在多个基准测试中表现出色,包括类条件图像生成高分辨率文本到图像生成文本到视频生成任务。

  • 成功训练了单一的像素空间模型,最高分辨率可达1024×1024像素。

  • 在使用仅包含1200万张图像的CC12M数据集上展示了强大的零样本泛化能力。

由 MDM 生成的图像分辨率分别为、、、和,使用提示符“穿日本和服的鹿宝宝套娃,超级细节,极其逼真,8K”; 用我们的方法生成的的视频1帧和16帧,使用提示符“将牛奶倒入黑咖啡”; 所有其他样本分别为。图像调整了大小以便于可视化。

Matryoshka扩散模型

本节介绍了Matryoshka扩散模型(MDM),一个新型的扩散模型类别,在高分辨率空间中进行训练,同时利用数据生成的层次结构。MDM首先在扩展空间中推广了标准扩散模型,为此提出了专门的嵌套架构和训练过程。

扩展空间中的扩散模型

与级联或潜在方法不同,MDM通过在扩展空间中引入多分辨率扩散过程,学习一个具有层次结构的单一扩散过程。下图2展示了该过程的示意图。

给定一个数据点 ,定义时间相关的潜变量 。类似于公式(1),对于每个 ,

其中, 是一个依赖于数据的确定性“下采样”算子。 是 的粗略/有损压缩版本。例如, 可以是用于生成低分辨率图像的 avgpool(.)。

默认情况下,假设逐步压缩方式满足 ,且 。此外, 是特定于分辨率的噪声计划。本文遵循 Gu et al. (2022) 的方法,并根据输入的分辨率调整噪声计划。MDM 通过 个神经去噪器 学习后向过程 。每个变量 在时间步 时依赖于所有分辨率 。在推理过程中,MDM 并行生成所有 个分辨率。各个 之间没有依赖关系。

在扩展空间中建模扩散过程具有明显的优点:

  • 由于推理过程中我们关心的是全分辨率输出 ,其他所有中间分辨率都被视为额外的隐藏变量 ,丰富了所建模分布的复杂性;

  • 多分辨率依赖性为 之间的权重和计算共享提供了机会,能够更有效地重新分配训练和推理中的计算资源。

NestedUNet架构

与典型的扩散模型类似,本文以UNet的形式实现了MDM:跳跃连接与计算块并行使用,以保留细粒度的输入信息,计算块由多级卷积和自注意力层组成。在MDM中,根据渐进压缩的假设, 的计算自然也对 有益。提出了NestedUNet架构,它将所有分辨率 的潜变量组合在一个去噪函数中作为嵌套结构,其中低分辨率的潜变量将随着标准的下采样逐步输入。这种多尺度计算共享极大地简化了高分辨率生成的学习过程。NestedUNet 与标准 UNet 的伪代码如下。

除了相对于其他层次化方法的简洁性,NestedUNet 还允许以最有效的方式分配计算。如下图 3 所示,早期探索发现,当将大部分参数和计算分配到最低分辨率时,MDM 在可扩展性方面表现得更好。

训练

研究者们在多个分辨率上联合使用正常的去噪目标来训练 MDM,具体如下:

其中, 是特定于分辨率的权重,默认情况下我们设置 。

渐进式训练

虽然 MDM 可以直接按照公式 (3) 进行端到端训练,已经显示出比简单基线更好的收敛性,但发现一个简单的渐进式训练技术,类似于 GAN 文献中提出的,大大加快了高分辨率模型的训练速度,特别是在墙钟时间方面。更具体地,将训练分为 个阶段,在这些阶段中,逐步将更高的分辨率加入到公式 (3) 的训练目标中。这相当于在 上学习一系列的 MDM,直到 达到最终分辨率。得益于所提出的架构,可以轻松实现上述目标,就像逐渐扩展网络一样。这种训练方案避免了从一开始就进行高分辨率训练,并加快了整体收敛速度。

实验

MDM 是一种通用技术,适用于任何输入维度可以逐步压缩的问题。考虑了两个超出类条件图像生成的应用,展示了本文方法的有效性——文本到图像和文本到视频生成。

实验设置

数据集
本文只关注公开可用且易于重现的数据集。对于图像生成,在 ImageNet上进行了类条件生成,分辨率为 ,并使用 Conceptual 12M 进行了通用文本到图像生成,分辨率为 和 。作为通用性的附加证据,展示了在 WebVid-10M 上进行的文本到视频生成结果,分辨率为 。在附录 F 中列出了数据集和预处理的详细信息。

在论文中广泛依赖 CC12M 作为文本到图像生成模型的数据集,显著不同于以往依赖极大且有时无法获取的数据集的研究。CC12M 足以构建高质量的文本到图像模型,并具备强大的zero-shot 能力,训练时间相对较短。这使得社区能够对方法进行更一致的比较,因为该数据集是免费提供的,且训练时间是可行的。因此 CC12M 更适合作为该问题研究的共同训练和评估基线。

评估
根据以往的研究,使用 Fréchet Inception Distance(ImageNet, CC12M)和 CLIP 分数(CC12M)来评估本文的图像生成模型。为了检查它们的zero-shot 能力,我们还报告了使用 COCO(Lin et al., 2014)验证集生成图像的 FID/CLIP 分数,使用 CC12M 训练的模型。我们还在补充材料中提供了图像和视频合成的其他定性样本。

实现细节
根据所提出的 NestedUNet 架构实现 MDM,最内层的 UNet 分辨率设置为 。类似于 Podell et al. (2023),将大部分自注意力层移动到较低的特征层(),最终内 UNet 总共有 450M 参数。如前所述,高分辨率部分可以很容易地附加到 NestedUNet 的前一层,且参数数量增加很小。对于文本到图像和文本到视频模型,使用冻结的 FLAN-T5 XL 作为文本编码器,因为它的规模适中,语言编码性能良好。此外,对文本表示应用了两个可学习的自注意力层,以增强文本与图像的对齐。

对于图像生成任务,对 的 MDM 进行的实验为 ,对 的实验为 。

对于视频生成,MDM 采用相同的图像 UNet 嵌套,并添加了注意力层以学习时间动态。整体分辨率为 。我们对空间进行双线性插值 ,对时间进行第一帧索引 。除非另有说明,我们对所有 MDM 应用渐进式和混合分辨率训练。使用了 8 台 A100 GPU 进行 ImageNet 的训练,使用 32 台 A100 GPU 进行 CC12M 和 WebVid-10M 的训练。

基线模型 除了与现有最先进方法的比较外,还在控制设置下对 MDM 与三个基线模型进行了详细分析:

  1. 简单 DM:将标准 UNet 架构直接应用于高分辨率输入;我们还考虑了 Nested UNet 架构,但忽略低分辨率损失;这两种情况本质上与最近的端到端扩散模型如 Hoogeboom 等(2023)相同。

  2. 级联 DM:遵循 Saharia 等(2022)的实现细节,训练一个直接与 MDM 可比较的 CDM,其中上采样器的配置与我们的 NestedUNet 相同。我们还对低分辨率条件图像应用噪声增强,并在推理过程中遍历最佳噪声水平。

  3. 潜在 DM:利用来自 Rombach 等(2022)自动编码器的潜在编码,随后训练与 MDM UNet 维度匹配的扩散模型。

主要结果

与基线方法的比较 与基线的比较结果如下图 4 所示。在 ImageNet 上,我们选择标准 UNet 作为我们的简单 DM 基线。对于级联 DM 基线,预训练一个 的扩散模型,训练 200K 次迭代,并应用同样大小的上采样 UNet。在推理时应用标准噪声增强,并遍历最佳噪声水平(我们发现这一点至关重要)。对于 LDM 实验,使用 Rombach 等 预训练的自动编码器,该编码器将输入分辨率下采样,我们在这些实验中使用与我们的 低分辨率模型相同的架构。对于 MDM 变体,使用与基线 UNet 相同大小的 NestedUNet。我们实验了两个变体,一个是直接使用多分辨率损失(式(3))进行训练(标记为 no PT),另一个是从 扩散模型恢复训练(即渐进式训练)。CC12M 的设置类似,只是我们使用单一损失的 NestedUNet 作为我们的简单 DM 架构。监控 ImageNet 上的 FID 曲线,以及 CC12M 上的 FID 和 CLIP 曲线。

比较简单 DM 和 MDM,可以清晰地看到 MDM 收敛速度更快,最终性能更好。这表明,多分辨率扩散过程结合多分辨率损失有效地改善了模型的收敛性,同时带来的复杂性微乎其微。当遵循渐进式训练计划时,我们看到 MDM 的性能和收敛速度进一步提高。作为直接比较,发现级联 DM 基线显著低于 MDM 的表现,尽管两者都从相同的 模型开始。需要注意的是,这一点非常显著,因为级联 DM 的参数总数大于 MDM(因为 MDM 在不同分辨率之间有广泛的参数共享),且推理步骤是其两倍。假设级联 DM 性能较差的主要原因在于 模型没有经过严格训练,这导致训练与推理之间在条件输入方面存在较大差距。最后,与 LDM 相比,MDM 的性能也更好。尽管这不是一个直接的对比,因为 LDM 确实由于其小输入尺寸而更有效,但 MDM 的训练和推理管道更为简单。

与文献的比较 在下表 1 中,MDM 与现有文献中的方法进行了比较,报告了 ImageNet 的 FID-50K 和 MSCOCO 的zero-shot FID-30K。对于 ImageNet,我们的架构和超参数并没有经过优化,MDM 能够达到 3.51 的竞争性 FID 值,与 CFG 相比。我们的 FID 结果与文献相当,尽管 MDM 在训练时使用的数据量明显少于基线模型,如 Imagen 和 DALL·E 2。

定性结果 下面展示训练后的 MDM 随机样本,用于图像生成(ImageNet ,下图 5)、文本到图像生成(CC12M, 下图 6)和文本到视频生成(WebVid-10M,下图 7)。尽管在相对较小的数据集上进行训练,MDM 展现出强大的zero-shot 能力,能够生成高分辨率的图像和视频。值得注意的是,对所有三个任务使用相同的训练流程,表明其处理各种数据类型的多样化能力。

消融研究

渐进式训练的效果 实验了渐进式训练计划,在该计划中,改变了低分辨率模型在继续训练目标分辨率之前的训练迭代次数(下图 8a)。看到更多的低分辨率训练明显有利于高分辨率的 FID 曲线。需要注意的是,在低分辨率输入上进行训练在内存和时间复杂度方面更为高效,因此渐进式训练为在训练过程中寻找最佳计算权衡提供了一种直接的选择。

嵌套层数的效果 接下来,比较了在 CC12M 上使用不同数量嵌套分辨率的性能。结果如上图 8b 所示。从两个分辨率层增加到三个分辨率层始终改善了模型的收敛性。值得注意的是,增加嵌套层数仅带来了微不足道的成本。

CLIP-FID 权衡 最后,在上图 8c 中展示了 COCO 的zero-shot 评估中 CLIP-FID 的帕累托曲线,这是通过改变无分类器引导(CFG)权重实现的。MDM 对 CFG 的响应与其他扩散模型变体类似。作为比较,叠加了 Imagen 报告的相同图(下图 A.11)。Imagen 通常展示出更小的 FID,这归因于其在大数据集上训练所导致的更高多样性。然而,MDM 展示出强大的 CLIP 得分,而在实践中发现,这种高 CLIP 得分与生成图像的视觉质量有很好的相关性。

讨论与未来方向

本文展示了跨不同分辨率共享表示可以加快训练速度并获得高质量结果,尤其是当低分辨率首先被训练时。相信这是因为模型能够更有效地利用不同分辨率之间的相关性,既在空间上也在时间上。尽管在这里仅探讨了一小部分架构,预计通过对权重共享架构的更详细探索,以及在当前架构中不同分辨率间参数分配的新方法,可以实现更多改进。

本文工作的另一个独特方面是使用了扩展空间,其中在多个分辨率上同时进行去噪。在这种形式下,时间和空间上的分辨率以相同的方式处理,时间和空间中的相关结构差异由权重共享模型的不同参数学习。对多分辨率的联合优化进行更一般的概念化的方法是将不同分辨率的损失解耦,并给予它们不同的权重。可以设想,在训练低分辨率到高分辨率的过程中可以实现平滑的过渡。

还注意到,尽管本文将方法与 LDM 进行了比较,这些方法是互补的。可以在自编码器代码的基础上构建 MDM。尽管并没有声称基于 MDM 的模型达到了最先进的水平,但将对 MDM 在大规模数据集和模型规模上的评估留作未来工作。

参考文献

[1] Matryoshka Diffusion Models