传统 CNFs 的训练方法:最大似然估计与 ODE 模拟以及Flow Matching“无仿真”(simulation-free)

解释传统连续归一化流(Continuous Normalizing Flows, CNFs)训练方法(以最大似然估计为例)为什么需要昂贵的 ODE 模拟,以及 Flow Matching(FM)提出的“无仿真”(simulation-free)方法是如何绕过这一问题的。我们会从原理、公式和计算过程逐步展开。


传统 CNFs 的训练方法:最大似然估计与 ODE 模拟

原理

连续归一化流(CNFs)通过一个时间依赖的向量场 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) )(由神经网络参数化)定义一个从简单分布(如标准正态分布 ( p 0 ( x ) = N ( 0 , I ) p_0(x) = \mathcal{N}(0, I) p0(x)=N(0,I) ))到目标数据分布 ( q ( x ) q(x) q(x) ) 的连续变换。这个变换由以下常微分方程(ODE)描述:

d d t ϕ t ( x ) = v t ( ϕ t ( x ) ; θ ) , ϕ 0 ( x ) = x \frac{d}{dt} \phi_t(x) = v_t(\phi_t(x); \theta), \quad \phi_0(x) = x dtdϕt(x)=vt(ϕt(x);θ),ϕ0(x)=x

  • ( ϕ t ( x ) \phi_t(x) ϕt(x) ):流映射,从 ( t = 0 t=0 t=0 ) 到 ( t = 1 t=1 t=1 ) 将 ( x x x) 从初始点变换到目标点。
  • ( p t ( x ) = [ ϕ t ] ∗ p 0 ( x ) p_t(x) = [\phi_t]_* p_0(x) pt(x)=[ϕt]p0(x) ):概率密度随时间 ( t t t ) 的演化。

最终,( p 1 ( x ) p_1(x) p1(x) ) 应逼近目标分布 ( q ( x ) q(x) q(x) )。

传统 CNFs 的训练目标是通过最大似然估计(Maximum Likelihood Estimation, MLE)优化参数 ( θ \theta θ ),使模型生成的 ( p 1 ( x ) p_1(x) p1(x) ) 与数据分布 ( q ( x ) q(x) q(x) ) 尽可能接近。

最大似然估计的目标

给定数据样本 ( { x ( i ) } i = 1 N ∼ q ( x ) \{x^{(i)}\}_{i=1}^N \sim q(x) { x(i)}i=1Nq(x) ),最大化对数似然:

L MLE ( θ ) = E q ( x ) [ log ⁡ p 1 ( x ; θ ) ] ≈ 1 N ∑ i = 1 N log ⁡ p 1 ( x ( i ) ; θ ) \mathcal{L}_{\text{MLE}}(\theta) = \mathbb{E}_{q(x)} [\log p_1(x; \theta)] \approx \frac{1}{N} \sum_{i=1}^N \log p_1(x^{(i)}; \theta) LMLE(θ)=Eq(x)[logp1(x;θ)]N1i=1Nlogp1(x(i);θ)

为了计算 ( p 1 ( x ; θ ) p_1(x; \theta) p1(x;θ) ),需要知道从 ( p 0 ( x ) p_0(x) p0(x) ) 到 ( p 1 ( x ) p_1(x) p1(x) ) 的概率密度演化,这依赖于流的变化公式。

概率密度演化与 ODE

根据 CNF 的推前公式(push-forward)和连续性方程,概率密度 ( p t ( x ) p_t(x) pt(x) ) 的演化满足:

p t ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ⁡ ( ∂ ϕ t − 1 ∂ x ) p_t(x) = p_0(\phi_t^{-1}(x)) \det\left( \frac{\partial \phi_t^{-1}}{\partial x} \right) pt(x)=p0(ϕt1(x))det(xϕt1)

或者等价地,使用雅可比行列式的对数:

log ⁡ p t ( x ) = log ⁡ p 0 ( ϕ t − 1 ( x ) ) + log ⁡ det ⁡ ( ∂ ϕ t − 1 ∂ x ) \log p_t(x) = \log p_0(\phi_t^{-1}(x)) + \log \det\left( \frac{\partial \phi_t^{-1}}{\partial x} \right) logpt(x)=logp0(ϕt1(x))+logdet(xϕt1)

在 ( t = 1 t=1 t=1 ) 时:

log ⁡ p 1 ( x ) = log ⁡ p 0 ( ϕ 1 − 1 ( x ) ) + log ⁡ det ⁡ ( ∂ ϕ 1 − 1 ∂ x ) \log p_1(x) = \log p_0(\phi_1^{-1}(x)) + \log \det\left( \frac{\partial \phi_1^{-1}}{\partial x} \right) logp1(x)=logp0(ϕ11(x))+logdet(xϕ11)

计算过程
  1. 求解 ODE

    • 从初始点 ( x 1 = x x_1 = x x1=x )(数据样本),需要计算 ( ϕ 1 − 1 ( x ) \phi_1^{-1}(x) ϕ11(x) ),即逆流回到 ( t = 0 t=0 t=0 ) 的位置。
    • 这要求解逆向 ODE:
      d d t ϕ t ( x ) = − v t ( ϕ t ( x ) ; θ ) , ϕ 1 ( x ) = x , t : 1 → 0 \frac{d}{dt} \phi_t(x) = -v_t(\phi_t(x); \theta), \quad \phi_1(x) = x, \quad t: 1 \to 0 dtdϕt(x)=vt(ϕt(x);θ),ϕ1(x)=x,t:10
    • 使用数值 ODE 求解器(如 Euler 方法或 Runge-Kutta),从 ( x x x ) 积分到 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) )。
  2. 计算雅可比行列式

    • ( log ⁡ det ⁡ ( ∂ ϕ 1 − 1 ∂ x ) \log \det\left( \frac{\partial \phi_1^{-1}}{\partial x} \right) logdet(xϕ11) ) 是流变换的体积变化,需要计算雅可比矩阵 ( ∂ ϕ t ∂ x \frac{\partial \phi_t}{\partial x} xϕt ) 的行列式。
    • 直接计算雅可比矩阵在高维中成本极高,因此通常用伴随方法(adjoint method)或 Hutchinson 迹估计器计算其对数:
      d d t log ⁡ det ⁡ ( ∂ ϕ t ∂ x ) = tr ⁡ ( ∂ v t ∂ x ( ϕ t ( x ) ) ) \frac{d}{dt} \log \det\left( \frac{\partial \phi_t}{\partial x} \right) = \operatorname{tr}\left( \frac{\partial v_t}{\partial x} (\phi_t(x)) \right) dtdlogdet(xϕt)=tr(xvt(ϕt(x)))
    • 这需要额外解一个伴随 ODE:
      d d t z t = tr ⁡ ( ∂ v t ∂ x ( ϕ t ( x ) ) ) , z 0 = 0 , z 1 = log ⁡ det ⁡ ( ∂ ϕ 1 ∂ x ) \frac{d}{dt} z_t = \operatorname{tr}\left( \frac{\partial v_t}{\partial x} (\phi_t(x)) \right), \quad z_0 = 0, \quad z_1 = \log \det\left( \frac{\partial \phi_1}{\partial x} \right) dtdzt=tr(xvt(ϕt(x))),z0=0,z1=logdet(xϕ1)
  3. 组合结果

    • 将 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) ) 代入 ( log ⁡ p 0 ( ϕ 0 ( x ) ) \log p_0(\phi_0(x)) logp0(ϕ0(x)) ),加上 ( z 1 z_1 z1 ),得到 ( log ⁡ p 1 ( x ) \log p_1(x) logp1(x) )。

为什么需要昂贵的 ODE 模拟?

  • 多次积分:每次计算 ( p 1 ( x ) p_1(x) p1(x) ) 需要从 ( t = 1 t=1 t=1 ) 到 ( t = 0 t=0 t=0 ) 数值求解 ODE,步数取决于精度要求(例如 100 步)。
  • 雅可比计算:伴随 ODE 或迹估计需要额外的计算,尤其在高维数据中(例如图像),每次前向传播和梯度计算都涉及大量矩阵操作。
  • 批量训练:对每个样本 ( x ( i ) x^{(i)} x(i) ) 都要重复上述过程,导致总成本随样本数和维度指数级增长。

例子:对于一张 28×28 的 MNIST 图像(784 维),每次似然计算可能需要数百步 ODE 模拟,每步涉及神经网络前向传播和雅可比迹估计,计算量巨大。


Flow Matching 的“无仿真”方法

具体可以可以参考笔者的另外的博客:
深入解析 Flow Matching:从条件概率路径与向量场到条件流匹配

深入解析 Flow Matching(二):从条件概率路径与向量场到条件流匹配

原理

Flow Matching(FM)提出了一种替代方法,避免直接计算 ( p 1 ( x ) p_1(x) p1(x) ) 和昂贵的 ODE 模拟。其核心目标仍是训练 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) ) 使 ( p 1 ( x ) ≈ q ( x ) p_1(x) \approx q(x) p1(x)q(x) ),但它通过回归目标向量场实现“无仿真”训练。

FM 的损失函数为:

L FM ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t, p_t(x)} \left\| v_t(x; \theta) - u_t(x) \right\|^2 LFM(θ)=Et,pt(x)vt(x;θ)ut(x)2

  • ( u t ( x ) u_t(x) ut(x) ):生成目标概率路径 ( p t ( x ) p_t(x) pt(x) ) 的真实向量场。
  • ( p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 )   d x 1 p_t(x) = \int p_t(x | x_1) q(x_1) \, dx_1 pt(x)=pt(xx1)q(x1)dx1 ):边缘概率路径。

由于直接计算 ( u t ( x ) u_t(x) ut(x) ) 和 ( p t ( x ) p_t(x) pt(x) ) 不可行,FM 引入条件流匹配(CFM):

L CFM ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x | x_1)} \left\| v_t(x; \theta) - u_t(x | x_1) \right\|^2 LCFM(θ)=Et,q(x1),pt(xx1)vt(x;θ)ut(xx1)2

“无仿真”的实现

  1. 条件路径与向量场

    • 定义一个已知的条件概率路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ),例如:
      p t ( x ∣ x 1 ) = N ( x ∣ t x 1 , ( 1 − t ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | t x_1, (1-t)^2 I) pt(xx1)=N(xtx1,(1t)2I)
    • 对应的条件向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 可以解析计算,例如:
      u t ( x ∣ x 1 ) = x 1 − x 1 − t u_t(x | x_1) = \frac{x_1 - x}{1-t} ut(xx1)=1tx1x
    • 这里不需要模拟整个流,只需根据 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ) 的定义直接得出 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) )。
  2. 采样与回归

    • 采样 ( t ∼ U [ 0 , 1 ] t \sim \mathcal{U}[0,1] tU[0,1] )、( x 1 ∼ q ( x 1 ) x_1 \sim q(x_1) x1q(x1) )、( x ∼ p t ( x ∣ x 1 ) x \sim p_t(x | x_1) xpt(xx1) )。
    • 对于每个采样点,直接计算 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),然后优化 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) ) 使其接近 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) )。
    • 无需解 ODE,因为 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ) 是解析形式,损失计算仅需一次前向传播。
  3. 生成过程

    • 训练完成后,从 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0N(0,I) ) 开始,解正向 ODE:
      d d t x t = v t ( x t ; θ ) , t : 0 → 1 \frac{d}{dt} x_t = v_t(x_t; \theta), \quad t: 0 \to 1 dtdxt=vt(xt;θ),t:01
    • 这部分需要 ODE 求解,但仅在推理阶段,与训练无关。

为什么是“无仿真”?

  • 训练时无 ODE:传统 CNF 需要在训练时反复解 ODE 计算 ( p 1 ( x ) p_1(x) p1(x) ) 和雅可比行列式,而 FM 通过条件路径的解析形式直接提供 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),无需模拟流的过程。
  • 计算效率:每次迭代只涉及采样和一次神经网络前向传播,成本远低于 ODE 模拟。

例子:对于 MNIST 图像,FM 训练时只需采样 ( t t t )、( x 1 x_1 x1 )、( x x x ),计算 ( u t ( x ∣ x 1 ) = x 1 − x 1 − t u_t(x | x_1) = \frac{x_1 - x}{1-t} ut(xx1)=1tx1x ),然后优化损失,无需数百步积分。


对比与总结

方面 传统 CNFs(MLE) Flow Matching(CFM)
训练目标 最大化 ( log ⁡ p 1 ( x ) \log p_1(x) logp1(x) ) 回归 ( v t ( x ) v_t(x) vt(x) ) 到 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) )
概率密度计算 通过 ODE 模拟 ( ϕ t \phi_t ϕt ) 和雅可比行列式 通过条件路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ) 解析定义
ODE 需求 训练时需反复解 ODE 训练时无需 ODE,仅推理时用
计算成本 高(多步积分 + 雅可比估计) 低(单次前向传播)
灵活性 通用但复杂 依赖条件路径设计

传统 CNFs 的 ODE 模拟过程

  • 步骤:从数据 ( x x x ) 逆向积分到 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) ),计算概率密度变化,涉及多次 ODE 求解和迹估计。
  • 昂贵原因:高维数据下,积分步数多,伴随计算复杂。

FM 的无仿真方法

  • 步骤:定义解析路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(xx1) ),直接回归 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(xx1) ),训练时无需模拟。
  • 高效原因:跳过了 ODE 和雅可比计算,损失直接基于向量场差异。

希望这个详细解释能帮你理解传统 CNF 和 Flow Matching 在训练上的本质区别!

后记

2025年4月8日14点24分于上海,在grok 3大模型辅助下完成。