解释传统连续归一化流(Continuous Normalizing Flows, CNFs)训练方法(以最大似然估计为例)为什么需要昂贵的 ODE 模拟,以及 Flow Matching(FM)提出的“无仿真”(simulation-free)方法是如何绕过这一问题的。我们会从原理、公式和计算过程逐步展开。
传统 CNFs 的训练方法:最大似然估计与 ODE 模拟
原理
连续归一化流(CNFs)通过一个时间依赖的向量场 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) )(由神经网络参数化)定义一个从简单分布(如标准正态分布 ( p 0 ( x ) = N ( 0 , I ) p_0(x) = \mathcal{N}(0, I) p0(x)=N(0,I) ))到目标数据分布 ( q ( x ) q(x) q(x) ) 的连续变换。这个变换由以下常微分方程(ODE)描述:
d d t ϕ t ( x ) = v t ( ϕ t ( x ) ; θ ) , ϕ 0 ( x ) = x \frac{d}{dt} \phi_t(x) = v_t(\phi_t(x); \theta), \quad \phi_0(x) = x dtdϕt(x)=vt(ϕt(x);θ),ϕ0(x)=x
- ( ϕ t ( x ) \phi_t(x) ϕt(x) ):流映射,从 ( t = 0 t=0 t=0 ) 到 ( t = 1 t=1 t=1 ) 将 ( x x x) 从初始点变换到目标点。
- ( p t ( x ) = [ ϕ t ] ∗ p 0 ( x ) p_t(x) = [\phi_t]_* p_0(x) pt(x)=[ϕt]∗p0(x) ):概率密度随时间 ( t t t ) 的演化。
最终,( p 1 ( x ) p_1(x) p1(x) ) 应逼近目标分布 ( q ( x ) q(x) q(x) )。
传统 CNFs 的训练目标是通过最大似然估计(Maximum Likelihood Estimation, MLE)优化参数 ( θ \theta θ ),使模型生成的 ( p 1 ( x ) p_1(x) p1(x) ) 与数据分布 ( q ( x ) q(x) q(x) ) 尽可能接近。
最大似然估计的目标
给定数据样本 ( { x ( i ) } i = 1 N ∼ q ( x ) \{x^{(i)}\}_{i=1}^N \sim q(x) { x(i)}i=1N∼q(x) ),最大化对数似然:
L MLE ( θ ) = E q ( x ) [ log p 1 ( x ; θ ) ] ≈ 1 N ∑ i = 1 N log p 1 ( x ( i ) ; θ ) \mathcal{L}_{\text{MLE}}(\theta) = \mathbb{E}_{q(x)} [\log p_1(x; \theta)] \approx \frac{1}{N} \sum_{i=1}^N \log p_1(x^{(i)}; \theta) LMLE(θ)=Eq(x)[logp1(x;θ)]≈N1i=1∑Nlogp1(x(i);θ)
为了计算 ( p 1 ( x ; θ ) p_1(x; \theta) p1(x;θ) ),需要知道从 ( p 0 ( x ) p_0(x) p0(x) ) 到 ( p 1 ( x ) p_1(x) p1(x) ) 的概率密度演化,这依赖于流的变化公式。
概率密度演化与 ODE
根据 CNF 的推前公式(push-forward)和连续性方程,概率密度 ( p t ( x ) p_t(x) pt(x) ) 的演化满足:
p t ( x ) = p 0 ( ϕ t − 1 ( x ) ) det ( ∂ ϕ t − 1 ∂ x ) p_t(x) = p_0(\phi_t^{-1}(x)) \det\left( \frac{\partial \phi_t^{-1}}{\partial x} \right) pt(x)=p0(ϕt−1(x))det(∂x∂ϕt−1)
或者等价地,使用雅可比行列式的对数:
log p t ( x ) = log p 0 ( ϕ t − 1 ( x ) ) + log det ( ∂ ϕ t − 1 ∂ x ) \log p_t(x) = \log p_0(\phi_t^{-1}(x)) + \log \det\left( \frac{\partial \phi_t^{-1}}{\partial x} \right) logpt(x)=logp0(ϕt−1(x))+logdet(∂x∂ϕt−1)
在 ( t = 1 t=1 t=1 ) 时:
log p 1 ( x ) = log p 0 ( ϕ 1 − 1 ( x ) ) + log det ( ∂ ϕ 1 − 1 ∂ x ) \log p_1(x) = \log p_0(\phi_1^{-1}(x)) + \log \det\left( \frac{\partial \phi_1^{-1}}{\partial x} \right) logp1(x)=logp0(ϕ1−1(x))+logdet(∂x∂ϕ1−1)
计算过程
-
求解 ODE:
- 从初始点 ( x 1 = x x_1 = x x1=x )(数据样本),需要计算 ( ϕ 1 − 1 ( x ) \phi_1^{-1}(x) ϕ1−1(x) ),即逆流回到 ( t = 0 t=0 t=0 ) 的位置。
- 这要求解逆向 ODE:
d d t ϕ t ( x ) = − v t ( ϕ t ( x ) ; θ ) , ϕ 1 ( x ) = x , t : 1 → 0 \frac{d}{dt} \phi_t(x) = -v_t(\phi_t(x); \theta), \quad \phi_1(x) = x, \quad t: 1 \to 0 dtdϕt(x)=−vt(ϕt(x);θ),ϕ1(x)=x,t:1→0 - 使用数值 ODE 求解器(如 Euler 方法或 Runge-Kutta),从 ( x x x ) 积分到 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) )。
-
计算雅可比行列式:
- ( log det ( ∂ ϕ 1 − 1 ∂ x ) \log \det\left( \frac{\partial \phi_1^{-1}}{\partial x} \right) logdet(∂x∂ϕ1−1) ) 是流变换的体积变化,需要计算雅可比矩阵 ( ∂ ϕ t ∂ x \frac{\partial \phi_t}{\partial x} ∂x∂ϕt ) 的行列式。
- 直接计算雅可比矩阵在高维中成本极高,因此通常用伴随方法(adjoint method)或 Hutchinson 迹估计器计算其对数:
d d t log det ( ∂ ϕ t ∂ x ) = tr ( ∂ v t ∂ x ( ϕ t ( x ) ) ) \frac{d}{dt} \log \det\left( \frac{\partial \phi_t}{\partial x} \right) = \operatorname{tr}\left( \frac{\partial v_t}{\partial x} (\phi_t(x)) \right) dtdlogdet(∂x∂ϕt)=tr(∂x∂vt(ϕt(x))) - 这需要额外解一个伴随 ODE:
d d t z t = tr ( ∂ v t ∂ x ( ϕ t ( x ) ) ) , z 0 = 0 , z 1 = log det ( ∂ ϕ 1 ∂ x ) \frac{d}{dt} z_t = \operatorname{tr}\left( \frac{\partial v_t}{\partial x} (\phi_t(x)) \right), \quad z_0 = 0, \quad z_1 = \log \det\left( \frac{\partial \phi_1}{\partial x} \right) dtdzt=tr(∂x∂vt(ϕt(x))),z0=0,z1=logdet(∂x∂ϕ1)
-
组合结果:
- 将 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) ) 代入 ( log p 0 ( ϕ 0 ( x ) ) \log p_0(\phi_0(x)) logp0(ϕ0(x)) ),加上 ( z 1 z_1 z1 ),得到 ( log p 1 ( x ) \log p_1(x) logp1(x) )。
为什么需要昂贵的 ODE 模拟?
- 多次积分:每次计算 ( p 1 ( x ) p_1(x) p1(x) ) 需要从 ( t = 1 t=1 t=1 ) 到 ( t = 0 t=0 t=0 ) 数值求解 ODE,步数取决于精度要求(例如 100 步)。
- 雅可比计算:伴随 ODE 或迹估计需要额外的计算,尤其在高维数据中(例如图像),每次前向传播和梯度计算都涉及大量矩阵操作。
- 批量训练:对每个样本 ( x ( i ) x^{(i)} x(i) ) 都要重复上述过程,导致总成本随样本数和维度指数级增长。
例子:对于一张 28×28 的 MNIST 图像(784 维),每次似然计算可能需要数百步 ODE 模拟,每步涉及神经网络前向传播和雅可比迹估计,计算量巨大。
Flow Matching 的“无仿真”方法
具体可以可以参考笔者的另外的博客:
深入解析 Flow Matching:从条件概率路径与向量场到条件流匹配
和
深入解析 Flow Matching(二):从条件概率路径与向量场到条件流匹配
原理
Flow Matching(FM)提出了一种替代方法,避免直接计算 ( p 1 ( x ) p_1(x) p1(x) ) 和昂贵的 ODE 模拟。其核心目标仍是训练 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) ) 使 ( p 1 ( x ) ≈ q ( x ) p_1(x) \approx q(x) p1(x)≈q(x) ),但它通过回归目标向量场实现“无仿真”训练。
FM 的损失函数为:
L FM ( θ ) = E t , p t ( x ) ∥ v t ( x ; θ ) − u t ( x ) ∥ 2 \mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t, p_t(x)} \left\| v_t(x; \theta) - u_t(x) \right\|^2 LFM(θ)=Et,pt(x)∥vt(x;θ)−ut(x)∥2
- ( u t ( x ) u_t(x) ut(x) ):生成目标概率路径 ( p t ( x ) p_t(x) pt(x) ) 的真实向量场。
- ( p t ( x ) = ∫ p t ( x ∣ x 1 ) q ( x 1 ) d x 1 p_t(x) = \int p_t(x | x_1) q(x_1) \, dx_1 pt(x)=∫pt(x∣x1)q(x1)dx1 ):边缘概率路径。
由于直接计算 ( u t ( x ) u_t(x) ut(x) ) 和 ( p t ( x ) p_t(x) pt(x) ) 不可行,FM 引入条件流匹配(CFM):
L CFM ( θ ) = E t , q ( x 1 ) , p t ( x ∣ x 1 ) ∥ v t ( x ; θ ) − u t ( x ∣ x 1 ) ∥ 2 \mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x | x_1)} \left\| v_t(x; \theta) - u_t(x | x_1) \right\|^2 LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x;θ)−ut(x∣x1)∥2
“无仿真”的实现
-
条件路径与向量场:
- 定义一个已知的条件概率路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ),例如:
p t ( x ∣ x 1 ) = N ( x ∣ t x 1 , ( 1 − t ) 2 I ) p_t(x | x_1) = \mathcal{N}(x | t x_1, (1-t)^2 I) pt(x∣x1)=N(x∣tx1,(1−t)2I) - 对应的条件向量场 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 可以解析计算,例如:
u t ( x ∣ x 1 ) = x 1 − x 1 − t u_t(x | x_1) = \frac{x_1 - x}{1-t} ut(x∣x1)=1−tx1−x - 这里不需要模拟整个流,只需根据 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ) 的定义直接得出 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) )。
- 定义一个已知的条件概率路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ),例如:
-
采样与回归:
- 采样 ( t ∼ U [ 0 , 1 ] t \sim \mathcal{U}[0,1] t∼U[0,1] )、( x 1 ∼ q ( x 1 ) x_1 \sim q(x_1) x1∼q(x1) )、( x ∼ p t ( x ∣ x 1 ) x \sim p_t(x | x_1) x∼pt(x∣x1) )。
- 对于每个采样点,直接计算 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),然后优化 ( v t ( x ; θ ) v_t(x; \theta) vt(x;θ) ) 使其接近 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) )。
- 无需解 ODE,因为 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) 是解析形式,损失计算仅需一次前向传播。
-
生成过程:
- 训练完成后,从 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ) 开始,解正向 ODE:
d d t x t = v t ( x t ; θ ) , t : 0 → 1 \frac{d}{dt} x_t = v_t(x_t; \theta), \quad t: 0 \to 1 dtdxt=vt(xt;θ),t:0→1 - 这部分需要 ODE 求解,但仅在推理阶段,与训练无关。
- 训练完成后,从 ( x 0 ∼ N ( 0 , I ) x_0 \sim \mathcal{N}(0, I) x0∼N(0,I) ) 开始,解正向 ODE:
为什么是“无仿真”?
- 训练时无 ODE:传统 CNF 需要在训练时反复解 ODE 计算 ( p 1 ( x ) p_1(x) p1(x) ) 和雅可比行列式,而 FM 通过条件路径的解析形式直接提供 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),无需模拟流的过程。
- 计算效率:每次迭代只涉及采样和一次神经网络前向传播,成本远低于 ODE 模拟。
例子:对于 MNIST 图像,FM 训练时只需采样 ( t t t )、( x 1 x_1 x1 )、( x x x ),计算 ( u t ( x ∣ x 1 ) = x 1 − x 1 − t u_t(x | x_1) = \frac{x_1 - x}{1-t} ut(x∣x1)=1−tx1−x ),然后优化损失,无需数百步积分。
对比与总结
方面 | 传统 CNFs(MLE) | Flow Matching(CFM) |
---|---|---|
训练目标 | 最大化 ( log p 1 ( x ) \log p_1(x) logp1(x) ) | 回归 ( v t ( x ) v_t(x) vt(x) ) 到 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ) |
概率密度计算 | 通过 ODE 模拟 ( ϕ t \phi_t ϕt ) 和雅可比行列式 | 通过条件路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ) 解析定义 |
ODE 需求 | 训练时需反复解 ODE | 训练时无需 ODE,仅推理时用 |
计算成本 | 高(多步积分 + 雅可比估计) | 低(单次前向传播) |
灵活性 | 通用但复杂 | 依赖条件路径设计 |
传统 CNFs 的 ODE 模拟过程
- 步骤:从数据 ( x x x ) 逆向积分到 ( ϕ 0 ( x ) \phi_0(x) ϕ0(x) ),计算概率密度变化,涉及多次 ODE 求解和迹估计。
- 昂贵原因:高维数据下,积分步数多,伴随计算复杂。
FM 的无仿真方法
- 步骤:定义解析路径 ( p t ( x ∣ x 1 ) p_t(x | x_1) pt(x∣x1) ),直接回归 ( u t ( x ∣ x 1 ) u_t(x | x_1) ut(x∣x1) ),训练时无需模拟。
- 高效原因:跳过了 ODE 和雅可比计算,损失直接基于向量场差异。
希望这个详细解释能帮你理解传统 CNF 和 Flow Matching 在训练上的本质区别!
后记
2025年4月8日14点24分于上海,在grok 3大模型辅助下完成。