近似推断 - 变分推断和学习篇

前言

在数据科学和机器学习的广阔领域中,变分推断(Variational Inference, VI)作为一种强大的近似推断方法,正逐渐崭露头角。随着大数据时代的到来,我们面临的数据集越来越庞大且复杂,传统的精确推断方法往往计算成本高昂,甚至在某些情况下变得不切实际。因此,变分推断以其高效、可扩展的特性,成为了处理大规模数据和复杂模型的首选方法。

序言

变分推断的核心思想在于,通过优化一个简单分布(称为变分分布)来近似复杂的后验分布。这种近似不仅降低了计算成本,还使得推断过程更加高效和灵活。同时,变分推断与机器学习中的学习算法紧密结合,通过迭代优化变分分布的参数,使得近似分布逐渐逼近真实的后验分布,从而实现参数的准确估计和模型的优化。

变分推断和学习

  • 我们已经说明过了:

    • (1)为什么证据下界 L ( v , θ , q ) \mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},q) L(v,θ,q) log ⁡ p ( v ; θ ) \log p(\boldsymbol{v}; \boldsymbol{\theta}) logp(v;θ) 的一个下界?
    • (2)如何将推断看做是关于 q q q 分布最大化 L \mathcal{L} L 的过程?
    • (3)以及如何将学习看做是关于参数 θ \boldsymbol{\theta} θ 最大化 L \mathcal{L} L 的过程。
  • 我们也讲到了 EM \text{EM} EM算法在给定了 q q q 分布的条件下能够进行大学习步骤,而基于 MAP \text{MAP} MAP推断的学习算法则是学习一个 p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv) 的点估计而非推断整个完整的分布。在这里我们介绍一些变分学习中更加通用的算法。

  • 变分学习的核心思想就是我们通过选择给定有约束的分布族中一个 q q q 分布来最大化 L \mathcal{L} L。选择这个分布族时应该考虑到计算 E q log ⁡ p ( h , v ) \mathbb{E}_q \log p(\boldsymbol{h} ,\boldsymbol{v}) Eqlogp(h,v) 的简单性。一个典型的方法就是添加一些 q q q 分布如何分解的假设。

  • 一种常用的变分学习的方法是加入一些限制使得 q q q 是一个因子分布:
    q ( h ∣ v ) = ∏ i q ( h i ∣ v ) q(\boldsymbol{h}\mid\boldsymbol{v})=\prod\limits_i q(h_i\mid\boldsymbol{v}) q(hv)=iq(hiv) — 公式1 \quad\textbf{---\footnotesize{公式1}} 公式1

  • 这被叫做是均匀场 ( mean-field \text{mean-field} mean-field) 方法。更一般地说,我们可以通过选择 q q q 分布的形式来选择任何图模型的结构,通过选择变量之间的相互作用来灵活地决定近似程度的大小。这种完全通用的图模型方法叫做结构化变分推断 ( structured variational inference \text{structured variational inference} structured variational inference) ( Saul and Jordan, 1996 \text{Saul and Jordan, 1996} Saul and Jordan, 1996)。

  • 变分方法的优点是我们不需要为分布 q q q 设定一个特定的参数化形式。

    • 我们设定它如何分解,之后通过解决优化问题来找出在这些分解限制下最优的概率分布。
    • 对离散型潜变量来说,这意味着我们使用了传统的优化技巧来优化描述 q q q 分布的有限个数变量。
    • 对连续型潜变量来说,这意味着我们使用了一个叫做变分法的数学分支来解决对一个函数空间的优化问题。
    • 然后决定哪一个函数来表示 q q q 分布。
    • 变分法是‘‘变分学习’’ 或者 ‘‘变分推断’’ 这些名字的来历,尽管当潜变量是离散时变分法并没有用武之地。
    • 当遇到连续型潜变量时, 变分法不需要过多地人工选择模型,是一种很有用的工具。
    • 我们只需要设定分布 q q q 如何分解,而不需要去猜测一个特定的能够精确近似原后验分布的 q q q 分布。
  • 因为 L ( v , θ , q ) \mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},q) L(v,θ,q) 定义成 log ⁡ p ( v ; θ ) − D KL ( q ( h ∣ v ) ∣ ∣ p ( h ∣ v ; θ ) ) \log p(\boldsymbol{v};\boldsymbol{\theta})-D_{\text{KL}}(q(\boldsymbol{h}\mid\boldsymbol{v})||p(\boldsymbol{h}\mid\boldsymbol{v};\boldsymbol{\theta})) logp(v;θ)DKL(q(hv)∣∣p(hv;θ)),我们可以认为关于 q q q 最大化 L \mathcal{L} L 的问题等价于(关于 q q q)最小化 D KL ( q ( h ∣ v ) ∣ ∣ p ( h ∣ v ) ) D_{\text{KL}}(q(\boldsymbol{h}\mid\boldsymbol{v})||p(\boldsymbol{h}\mid\boldsymbol{v})) DKL(q(hv)∣∣p(hv))

    • 在这种情况下,我们要用 q q q 来拟合 p p p
    • 然而,我们并不是直接拟合一个近似,而是处理一个 KL \text{KL} KL散度的问题。
    • 当我们使用最大似然学习来用模型拟合数据时,我们最小化 D KL ( p data ∣ ∣ p model ) D_{\text{KL}}(p_{\text{data}}||p_{\text{model}}) DKL(pdata∣∣pmodel)
    • 如同图例1中所示,这意味着最大似然促进模型在每一个数据达到更高概率的地方达到更高的概率,而基于优化的推断则促进了 q q q 在每一个真实后验分布概率较低的地方概率较小。
    • 这两种基于 KL \text{KL} KL散度的方法都有各自的优点与缺点。
    • 选择哪一种方法取决于在具体每一个应用中哪一种性质更受偏好。
    • 在基于优化的推断问题中,从计算角度考虑,我们选择使用 D KL ( q ( h ∣ v ) ∣ ∣ p ( h ∣ v ) ) D_{\text{KL}}(q(\boldsymbol{h}\mid\boldsymbol{v})||p(\boldsymbol{h}\mid\boldsymbol{v})) DKL(q(hv)∣∣p(hv))
    • 具体地说,计算 D KL ( q ( h ∣ v ) ∣ ∣ p ( h ∣ v ) ) D_{\text{KL}}(q(\boldsymbol{h}\mid\boldsymbol{v})||p(\boldsymbol{h}\mid\boldsymbol{v})) DKL(q(hv)∣∣p(hv)) 涉及到了计算 q q q 分布下的期望。
    • 所以通过将分布 q q q 设计得较为简单,我们可以简化求所需要的期望的计算过程。
    • 另一个方向的 KL \text{KL} KL 散度需要计算真实后验分布下的期望。因为真实后验分布的形式是由模型的选择决定的,我们不能设计出一种能够精确计算 D KL ( p ( h ∣ v ) ∣ ∣ q ( h ∣ v ) ) D_{\text{KL}}(p(\boldsymbol{h}\mid\boldsymbol{v})||q(\boldsymbol{h}\mid\boldsymbol{v})) DKL(p(hv)∣∣q(hv)) 的开销较小的方法。

  • 图例1: KL \text{KL} KL散度是不对称的。
    • KL \text{KL} KL散度是不对称的
      在这里插入图片描述

    • 说明:

      • 假设我们有一个分布 p ( x ) p(x) p(x),并且希望用另一个分布 q ( x ) q(x) q(x) 来近似它。
      • 我们可以选择最小化 D KL ( p ∣ ∣ q ) D_{\text{KL}}(p||q) DKL(p∣∣q) 或最小化 D KL ( q ∣ ∣ p ) D_{\text{KL}}(q||p) DKL(q∣∣p)
      • 为了说明每种选择的效果,我们令 p p p 是两个高斯分布的混合,令 q q q 为单个高斯分布。
      • 选择使用 KL {\text{KL}} KL 散度的哪个方向是取决于问题的。
      • 一些应用需要这个近似分布 q q q 在真实分布 p p p 放置高概率的所有地方都放置高概率,而其他应用需要这个近似分布 q q q 在真实分布 p p p 放置低概率的所有地方都很少放置高概率。
      • KL {\text{KL}} KL 散度方向的选择反映了对于每种应用,优先考虑哪一种选择。
      • 左图:
        • 最小化 D KL ( p ∣ ∣ q ) D_{\text{KL}}(p||q) DKL(p∣∣q) 的效果。
        • 在这种情况下,我们选择一个 q q q 使得它在 p p p 具有高概率的地方具有高概率。
        • p p p 具有多个峰时, q q q 选择将这些峰模糊到一起,以便将高概率质量放到所有峰上。
      • 右图:
        • 最小化 D KL ( q ∣ ∣ p ) D_{\text{KL}}(q||p) DKL(q∣∣p) 的效果。
        • 在这种情况下,我们选择一个 q q q 使得它在 p p p 具有低概率的地方具有低概率。
        • p p p 具有多个峰并且这些峰间隔很宽时,如该图所示,最小化 KL {\text{KL}} KL 散度会选择单个峰,以避免将概率质量放置在 p p p 的多个峰之间的低概率区域中。
        • 这里,我们说明当 q q q 被选择成强调左边峰时的结果。
        • 我们也可以通过选择右边峰来得到 KL {\text{KL}} KL 散度相同的值。
        • 如果这些峰没有被足够强的低概率区域分离,那么 KL {\text{KL}} KL 散度的这个方向仍然可能选择模糊这些峰。

离散型潜变量

  • 关于离散型潜变量的变分推断相对来说比较直接。

    • 我们定义一个分布 q q q,通常 q q q 的每个因子都由一些离散状态的可查询表格定义。
    • 在最简单的情况中, h \boldsymbol{h} h 是二元的并且我们做了均匀场假定, q q q 可以根据每一个 h i h_i hi 分解。
    • 在这种情况下,我们可以用一个向量 h ^ \hat{\boldsymbol{h}} h^ 来参数化 q q q 分布, h ^ \hat{\boldsymbol{h}} h^ 的每一个元素都代表一个概率,即 q ( h i = 1 ∣ v ) = h ^ i q(h_i = 1 \mid \boldsymbol{v}) = \hat{h}_i q(hi=1v)=h^i
  • 在确定了如何表示 q q q 以后,我们只需要优化它的参数。在离散型潜变量模型中,这是一个标准的优化问题。基本上 q q q 的选择可以通过任何优化算法解决,比如说梯度下降。

  • 这个优化问题是很高效的因为它在许多学习算法的内循环中出现。

    • 为了追求速度,我们通常使用特殊设计的优化算法。
    • 这些算法通常能够在极少的循环内解决一些小而简单的问题。
    • 一个常见的选择是使用不动点方程,换句话说,就是解关于 h ^ i \hat{h}_i h^i的方程: ∂ ∂ h ^ i L = 0 \displaystyle\frac{\partial}{\partial\hat{h}_i}\mathcal{L}=0 h^iL=0 — 公式2 \quad\textbf{---\footnotesize{公式2}} 公式2
      我们反复地更新 h ^ \hat{\boldsymbol{h}} h^ 不同的元素直到收敛准则满足。
  • 为了具体化这些描述,我们接下来会讲如何将变分推断应用到二值稀疏编码 ( binary sparse coding \text{binary sparse coding} binary sparse coding) 模型(这里我们所描述的模型是 Henniges et al. (2010) \text{Henniges et al. (2010)} Henniges et al. (2010) 提出的,但是我们采用了传统、通用的均匀场方法,而原文作者采用了一种特殊设计的算法)中。

    • 推导过程在数学上非常详细,为希望完全了解我们描述过的变分推断和学习高级概念描述的读者所准备。
    • 而对于并不计划推导或者实现变分学习算法的读者来说,可以放心跳过,直接阅读变分法,这并不会导致新的高级概念的遗漏。
    • 建议那些从事二值稀疏编码研究的读者可以重新看一下应用数学与机器学习基础 - 概率与信息论篇 - 常用函数的一些性质中描述的一些经常在概率模型中出现的有用的函数性质。
    • 我们在推导过程中随意地使用了这些性质,并没有特别强调它们。
  • 在二值稀疏编码模型中,输入 v ∈ R n \boldsymbol{v} \in \mathbb{R}^n vRn,是由模型通过添加高斯噪音到 m m m 个或有或无的成分。

    • 每一个成分可以是开或者关的,对应着隐藏单元 h ∈ { 0 , 1 } m \boldsymbol{h}\in\{0,1\}^m h{ 0,1}m
      p ( h i = 1 ) = σ ( b i ) p(h_i=1)=\sigma(b_i) p(hi=1)=σ(bi) — 公式3 \quad\textbf{---\footnotesize{公式3}} 公式3
      p ( v ∣ h ) = N ( v ; W h , β ( − 1 ) ) p(\boldsymbol{v}\mid\boldsymbol{h})=\mathcal{N}(\boldsymbol{v};\boldsymbol{Wh},\beta^{(-1)}) p(vh)=N(v;Wh,β(1)) — 公式4 \quad\textbf{---\footnotesize{公式4}} 公式4
    • 其中 b \boldsymbol{b} b 是一个可以学习的偏置集合, W \boldsymbol{W} W 是一个可以学习的权值矩阵, β \beta β 是一个可以学习的对角精度矩阵。
  • 使用最大似然来训练这样一个模型需要对参数进行求导。我们考虑对其中一个偏置进行求导的过程:
    { ∂ ∂ b i log ⁡ p ( v ) — 公式5 = ∂ ∂ b i p ( v ) p ( v ) — 公式6 = ∂ ∂ b i ∑ h p ( h ) p ( v ∣ h ) p ( v ) — 公式7 = ∑ h p ( v ∣ h ) ∂ ∂ b i p ( h ) p ( v ) — 公式8 = ∑ h p ( h ∣ v ) ∂ ∂ b i p ( h ) p ( h ) — 公式9 = E h ∼ p ( h ∣ v ) ∂ ∂ b i log ⁡ p ( h ) — 公式10 \begin{cases} \begin{aligned} &\quad \frac{\partial}{\partial b_i} \log p(\boldsymbol{v}) &\quad\textbf{---\footnotesize{公式5}}\\\\ &=\frac{\frac{\partial}{\partial b_i}p(\boldsymbol{v})}{p(\boldsymbol{v})} &\quad\textbf{---\footnotesize{公式6}}\\\\ &=\frac{\frac{\partial}{\partial b_i}\sum_h p(\boldsymbol{h})p(\boldsymbol{v}\mid\boldsymbol{h})}{p(\boldsymbol{v})} &\quad\textbf{---\footnotesize{公式7}}\\\\ &=\frac{\sum_h p(\boldsymbol{v}\mid\boldsymbol{h})\frac{\partial}{\partial b_i}p(\boldsymbol{h})}{p(\boldsymbol{v})} &\quad\textbf{---\footnotesize{公式8}}\\\\ &=\sum_h p(\boldsymbol{h}\mid\boldsymbol{v})\frac{\frac{\partial}{\partial b_i}p(\boldsymbol{h})}{p(\boldsymbol{h})}&\quad\textbf{---\footnotesize{公式9}}\\\\ &=\mathbb{E}_{\boldsymbol{h}\sim p(\boldsymbol{h}\mid\boldsymbol{v})} \frac{\partial}{\partial b_i} \log p(\boldsymbol{h}) &\quad\textbf{---\footnotesize{公式10}} \end{aligned} \end{cases} bilogp(v)=p(v)bip(v)=p(v)bihp(h)p(vh)=p(v)hp(vh)bip(h)=hp(hv)p(h)bip(h)=Ehp(hv)bilogp(h)公式5公式6公式7公式8公式9公式10

  • 这需要计算 p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv) 下的期望。

    • 不幸的是, p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv) 是一个很复杂的分布。
    • p ( h , v ) p(\boldsymbol{h}, \boldsymbol{v}) p(h,v) p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv) 的图结构见图例2
    • 隐藏单元的后验分布对应的是关于隐藏单元的完全图,所以相对于暴力算法,消元算法并不能有助于提高计算所需要的期望的效率。

  • 图例2:包含四个隐藏单元的二值稀疏编码的图结构。

    • 包含四个隐藏单元的二值稀疏编码的图结构
      在这里插入图片描述

    • 说明:

      • 左图:
        • p ( h , v ) p(\boldsymbol{h}, \boldsymbol{v}) p(h,v) 的图结构。
        • 要注意边是有向的,每两个隐藏单元都是每个可见单元的 coparent \text{coparent} coparent
      • 右图:
        • p ( h , v ) p(\boldsymbol{h}, \boldsymbol{v}) p(h,v)的图结构。
        • 为了解释 coparent \text{coparent} coparent之间的活跃路径,后验分布所有隐藏单元之间都有边。
  • 取而代之的是,我们可以应用变分推断和变分学习来解决这个难点。

  • 我们可以做一个均匀场近似:
    q ( h ∣ v ) = ∏ q ( h i ∣ v ) q(\boldsymbol{h}\mid\boldsymbol{v})=\prod q(h_i\mid\boldsymbol{v}) q(hv)=q(hiv) — 公式11 \quad\textbf{---\footnotesize{公式11}} 公式11

  • 二值稀疏编码中的潜变量是二值的,所以为了表示可分解的 q q q 我们假设对 m m m Bernoulli \text{Bernoulli} Bernoulli 分布 q ( h i ∣ v ) q(h_i \mid \boldsymbol{v}) q(hiv) 建模。

    • 表示 Bernoulli \text{Bernoulli} Bernoulli 分布的一种很自然的方法是使用一个概率向量 h ^ \hat{\boldsymbol{h}} h^,满足 q ( h i ∣ v ) = h ^ i q(h_i \mid \boldsymbol{v}) = \hat{h}^i q(hiv)=h^i
    • 为了避免计算中的误差,比如说计算 log ⁡ h ^ i \log \hat{h}^i logh^i 时,我们对 h ^ i \hat{h}^i h^i 添加一个约束,即 h ^ i \hat{h}^i h^i 不等于 0 0 0 或者 1 1 1
  • 我们将会看到变分推断方程理论上永远不会赋予 h ^ i \hat{h}^i h^i 0 0 0 或者 1 1 1

    • 然而在软件实现过程中,机器的舍入误差会导致 0 0 0 或者 1 1 1 的值。
    • 在二值稀疏编码的软件实现中,我们希望使用一个没有限制的变分参数向量 z z z 以及通过关系 h ^ = σ ( z ) \hat{\boldsymbol{h}} = \sigma(z) h^=σ(z) 来获得 h \boldsymbol{h} h
    • 因此我们可以放心地在计算机上计算 log ⁡ h ^ i \log \hat{h}^i logh^i 通过使用关系式 log ⁡ σ ( z i ) = − ζ ( − z i ) \log\sigma(z_i) = −\zeta(−z_i) logσ(zi)=ζ(zi) 来建立 sigmoid \text{sigmoid} sigmoid函数和 softplus \text{softplus} softplus函数的关系。
  • 在开始二值稀疏编码模型中变分学习的推导时,我们首先说明了均匀场近似的使用可以使得学习过程更加简单。

  • 证据下界可以表示为:
    { L ( v , θ , q ) — 公式12 = E h ∼ q [ log ⁡ p ( h , v ) ] + H ( q ) — 公式13 = E h ∼ q [ log ⁡ p ( h ) + log ⁡ p ( v ∣ h ) − log ⁡ q ( h ∣ v ) ] — 公式14 = E h ∼ q [ ∑ i = 1 m log ⁡ p ( h i ) + ∑ i = 1 m log ⁡ p ( v i ∣ h ) − ∑ i = 1 m log ⁡ q ( h i ∣ v ) ] — 公式15 = ∑ i = 1 m [ h ^ i ( log ⁡ σ ( b i ) − log ⁡ h ^ i ) + ( 1 − h ^ i ) ( log ⁡ σ ( − b i ) − log ⁡ ( 1 − h ^ i ) ) ] — 公式16 + E h ∼ q [ ∑ i = 1 n log ⁡ β i 2 π e ( − β i 2 ( v i − W i , : h ) 2 ) ] — 公式17 = ∑ i = 1 m [ h ^ i ( log ⁡ σ ( b i ) − log ⁡ h ^ i ) + ( 1 − h ^ i ) ( log ⁡ σ ( − b i ) − log ⁡ ( 1 − h ^ i ) ) ] — 公式18 + 1 2 ∑ i = 1 [ log ⁡ β i 2 π − β i ( v i 2 − 2 v i W i , : h ^ + ∑ j [ W i , j 2 h ^ j + ∑ k ≠ j W i , j W i , k h ^ j , h ^ k ] ) ] — 公式19 \begin{cases} \begin{aligned} &\quad\mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},q) &\quad\textbf{---\footnotesize{公式12}}\\\\ &=\mathbb{E}_{\boldsymbol{h}\sim q}[\log p(\boldsymbol{h},\boldsymbol{v})]+H(q) &\quad\textbf{---\footnotesize{公式13}}\\\\ &=\mathbb{E}_{\boldsymbol{h}\sim q}[\log p(\boldsymbol{h})+\log p(\boldsymbol{v}\mid\boldsymbol{h})-\log q(\boldsymbol{h}\mid\boldsymbol{v})]&\quad\textbf{---\footnotesize{公式14}}\\\\ &=\mathbb{E}_{\boldsymbol{h}\sim q}[\sum\limits_{i=1}^m\log p(h_i)+\sum\limits_{i=1}^m\log p(v_i\mid\boldsymbol{h})-\sum\limits_{i=1}^m\log q(h_i\mid\boldsymbol{v})]&\quad\textbf{---\footnotesize{公式15}}\\\\ &=\sum\limits_{i=1}^m\left[\hat{h}_i(\log\sigma(b_i)-\log\hat{h}_i)+(1-\hat{h}_i)(\log\sigma(-b_i)-\log(1-\hat{h}_i))\right] &\quad\textbf{---\footnotesize{公式16}}\\\\ &\quad+\mathbb{E}_{\boldsymbol{h}\sim q}\left[\sum\limits_{i=1}^n\log\sqrt{\frac{\beta_i}{2\pi}}e^{\left(\displaystyle-\frac{\beta_i}{2}(v_i-\boldsymbol{W}_{i,:}\boldsymbol{h})^2\right)}\right]&\quad\textbf{---\footnotesize{公式17}}\\\\ &=\sum\limits_{i=1}^m\left[\hat{h}_i(\log\sigma(b_i)-\log\hat{h}_i)+(1-\hat{h}_i)(\log\sigma(-b_i)-\log(1-\hat{h}_i))\right]&\quad\textbf{---\footnotesize{公式18}}\\\\ &\quad+\frac{1}{2}\sum\limits_{i=1}\left[\log\frac{\beta_i}{2\pi}-\beta_i\left(v_i^2-2v_i\boldsymbol{W}_{i,:}\hat{\boldsymbol{h}}+\sum\limits_j\left[W_{i,j}^2\hat{h}_j+\sum\limits_{k\ne j}W_{i,j}W_{i,k}\hat{h}_j,\hat{h}_k\right]\right) \right]&\quad\textbf{---\footnotesize{公式19}}\\\\ \end{aligned} \end{cases} L(v,θ,q)=Ehq[logp(h,v)]+H(q)=Ehq[logp(h)+logp(vh)logq(hv)]=Ehq[i=1mlogp(hi)+i=1mlogp(vih)i=1mlogq(hiv)]=i=1m[h^i(logσ(bi)logh^i)+(1h^i)(logσ(bi)log(1h^i))]+Ehq i=1nlog2πβi e(2βi(viWi,:h)2) =i=1m[h^i(logσ(bi)logh^i)+(1h^i)(logσ(bi)log(1h^i))]+21i=1 log2πβiβi vi22viWi,:h^+j Wi,j2h^j+k=jWi,jWi,kh^j,h^k 公式12公式13公式14公式15公式16公式17公式18公式19
    — 公式19 \quad\textbf{---\footnotesize{公式19}} 公式19

  • 尽管这些方程从美学观点来看有些不尽如人意。

    • 他们展示了 L \mathcal{L} L 可以被表示为少量简单的代数运算。
    • 因此证据下界 L \mathcal{L} L是易于处理的。
    • 我们可以把 L \mathcal{L} L 看作是难以处理的对数似然函数的一个替代。
  • 原则上说,我们可以使用关于 v \boldsymbol{v} v h \boldsymbol{h} h 的梯度上升。

    • 这会成为一个完美的组合(推断算法和学习算法)的推断和学习算法。
    • 但是,由于两个原因,我们往往不这么做。
    • 第一点,对每一个 v \boldsymbol{v} v 我们需要存储 h ^ \hat{\boldsymbol{h}} h^
    • 我们通常更加偏向于那些不需要为每一个样本都准备内存的算法。
    • 如果我们需要为每一个样本都存储一个动态更新的向量,使得算法很难处理好几亿样本。
    • 第二个原因就是为了能够识别 v \boldsymbol{v} v 的内容,我们希望能够有能力快速提取特征 h ^ \hat{\boldsymbol{h}} h^
    • 在实际应用场景中,我们需要在有限时间内计算出 h ^ \hat{\boldsymbol{h}} h^
  • 由于以上两个原因,我们通常不会采用梯度下降来计算均匀场参数 h ^ \hat{\boldsymbol{h}} h^。取而代之的是,我们使用不动点方程来快速估计他们。

  • 不动点方程的核心思想是我们寻找一个关于 h \boldsymbol{h} h的局部极大点, 满足 ∇ h L ( v , θ , h ^ ) = 0 \nabla_h\mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},\hat{\boldsymbol{h}})=0 hL(v,θ,h^)=0。我们无法同时高效地计算所有 h ^ \hat{\boldsymbol{h}} h^ 的元素。然而,我们可以解决单个变量的问题:
    ∂ ∂ h ^ i L ( v , θ , h ^ ) = 0 \frac{\partial}{\partial\hat{h}_i}\mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},\hat{\boldsymbol{h}})=0 h^iL(v,θ,h^)=0 — 公式20 \quad\textbf{---\footnotesize{公式20}} 公式20

  • 我们可以迭代地将这个解应用到 i = 1 , … , m i = 1,\dots,m i=1,,m,然后重复这个循环直到我们满足了收敛准则。常见的收敛准则包含了当整个循环所改进的 L \mathcal{L} L 不超过预设的容差量时停止,或者是循环中改变的 h ^ \hat{\boldsymbol{h}} h^ 不超过某个值时停止。

  • 在很多不同的模型中,迭代的均匀场不动点方程是一种能够提供快速变分推断的通用算法。为了使它更加具体化,我们详细地讲一下如何推导出二值稀疏编码模型的更新过程。

  • 首先,我们给出了对 h ^ i \hat{h}_i h^i的导数表达式。为了得到这个表达式,我们将公式19代入公式20的左边:
    { ∂ ∂ h ^ i L ( v , θ , h ^ ) — 公式21 = ∂ ∂ h ^ i [ ∑ j = 1 m [ h ^ j ( log ⁡ σ ( b j ) − log ⁡ h ^ j ) + ( 1 − h ^ j ) ( log ⁡ σ ( − b j ) − log ⁡ ( 1 − h ^ j ) ) ] + 1 2 ∑ j = 1 n [ log ⁡ β j 2 π − β j ( v j 2 − 2 v j W j , : h ^ + ∑ k [ W j , k 2 h ^ k + ∑ l ≠ k W j , k W j , l h ^ k h ^ l ] ) ] ] — 公式22 = log ⁡ σ ( b i ) − log ⁡ h ^ i − 1 + log ⁡ ( 1 − h ^ i ) + 1 − log ⁡ σ ( − b i ) + ∑ j = 1 n [ β i ( v j W j , i − 1 2 W j , i 2 − ∑ k ≠ i W j , k W j , i h ^ k ) ] — 公式23 = b i − log ⁡ h ^ i + log ⁡ ( 1 − h ^ i ) + v ⊤ β W : , i − 1 2 W : , i ⊤ β W : , i − ∑ j ≠ i W : , j ⊤ β W : , i h ^ j — 公式24 \begin{cases} \begin{aligned} &\quad\frac{\partial}{\partial\hat{h}_i}\mathcal{L}(\boldsymbol{v},\boldsymbol{\theta},\hat{\boldsymbol{h}}) &\quad\textbf{---\footnotesize{公式21}}\\\\ &=\frac{\partial}{\partial\hat{h}_i}\left[\sum\limits_{j=1}^m\left[\hat{h}_j(\log\sigma(b_j)-\log\hat{h}_j)+(1-\hat{h}_j)(\log\sigma(-b_j)-\log(1-\hat{h}_j))\right] +\frac{1}{2}\sum\limits_{j=1}^n\left[\log\frac{\beta_j}{2\pi}-\beta_j\left(v_j^2-2v_j\boldsymbol{W}_{j,:}\hat{\boldsymbol{h}}+\sum\limits_k\left[W_{j,k}^2\hat{h}_k+\sum\limits_{l\ne k}W_{j,k}W_{j,l}\hat{h}_k\hat{h}_l\right]\right)\right]\right] &\quad\textbf{---\footnotesize{公式22}}\\\\ &=\log\sigma(b_i)-\log\hat{h}_i-1+\log(1-\hat{h}_i)+1-\log\sigma(-b_i)+\sum\limits_{j=1}^n\left[\beta_i\left(v_jW_{j,i}-\frac{1}{2}W_{j,i}^2-\sum\limits_{k\ne i}\boldsymbol{W}_{j,k}\boldsymbol{W}_{j,i}\hat{h}_k\right)\right] &\quad\textbf{---\footnotesize{公式23}}\\\\ &=b_i-\log\hat{h}_i+\log(1-\hat{h}_i)+\boldsymbol{v}^\top\beta\boldsymbol{W}_{:,i}-\frac{1}{2}\boldsymbol{W}_{:,i}^\top\beta\boldsymbol{W}_{:,i}-\sum\limits_{j\ne i}\boldsymbol{W}_{:,j}^\top\beta\boldsymbol{W}_{:,i}\hat{h}_j &\quad\textbf{---\footnotesize{公式24}} \end{aligned} \end{cases} h^iL(v,θ,h^)=h^i j=1m[h^j(logσ(bj)logh^j)+(1h^j)(logσ(bj)log(1h^j))]+21j=1n log2πβjβj vj22vjWj,:h^+k Wj,k2h^k+l=kWj,kWj,lh^kh^l =logσ(bi)logh^i1+log(1h^i)+1logσ(bi)+j=1n βi vjWj,i21Wj,i2k=iWj,kWj,ih^k =bilogh^i+log(1h^i)+vβW:,i21W:,iβW:,ij=iW:,jβW:,ih^j公式21公式22公式23公式24
    — 公式24 \quad\textbf{---\footnotesize{公式24}} 公式24

  • 为了应用固定点更新的推断规则,我们通过令公式24等于 0 0 0 来解 h ^ i \hat{h}_i h^i
    h ^ i = σ ( b i + v ⊤ β W : , i − 1 2 W : , i ⊤ β W : , i − ∑ j ≠ i W : , j ⊤ β W : , i h ^ j ) \hat{h}_i=\sigma\left(b_i+\boldsymbol{v}^\top\beta\boldsymbol{W}_{:,i}-\displaystyle\frac{1}{2}\boldsymbol{W}_{:,i}^\top\beta\boldsymbol{W}_{:,i}-\sum\limits_{j\ne i}\boldsymbol{W}_{:,j}^\top\beta\boldsymbol{W}_{:,i}\hat{h}_j\right) h^i=σ bi+vβW:,i21W:,iβW:,ij=iW:,jβW:,ih^j — 公式25 \quad\textbf{---\footnotesize{公式25}} 公式25

  • 此时,我们可以发现图模型中的推断和循环神经网络之间存在着紧密的联系。

    • 具体地说, 均匀场不动点方程定义了一个循环神经网络。
    • 这个神经网络的任务就是完成推断。
    • 我们已经从模型描述角度介绍了如何推导这个网络,但是直接训练这个推断网络也是可行的。
    • 有关这种思路的一些想法在后续篇章:深度生成模型中有所描述。
  • 在二值稀疏编码模型中,我们可以发现公式25中描述的循环网络连接包含了根据相邻隐藏单元变化值来反复更新当前隐藏单元的操作。

    • 输入层通常给隐藏单元发送一个固定的信息 v ⊤ β W \boldsymbol{v}^\top\beta\boldsymbol{W} vβW,然而隐藏单元不断地更新互相传送的信息。
    • 具体地说,当 h ^ i \hat{h}_i h^i h ^ j \hat{h}_j h^j 两个单元的权重向量对准时,他们会产生相互抑制。
    • 这也是一种形式的竞争——两个共同解释输入的隐藏单元之间,只有一个解释得更好的才被允许继续保持活跃。
    • 在二值稀疏编码的后验分布中, 均匀场近似为了捕获到更多的 explaining away \text{explaining away} explaining away作用,产生了这种竞争。
    • 事实上, explaining away \text{explaining away} explaining away效应会导致一个多峰值的后验分布,以至于我们如果从后验分布中采样,一些样本只有一个结点是活跃的,其他的样本在另一个结点活跃,只有很少的样本能够两者都处于活跃状态。
    • 不幸的是, explaining away \text{explaining away} explaining away作用无法通过均匀场中因子 q q q 分布来建模,因此建模时均匀场近似只能选择一个峰值。
    • 这个现象的一个例子可以参考图例1
  • 我们将公式25重写成等价的形式来揭示一些深层的含义:
    h ^ i = σ ( b i + ( v − ∑ j ≠ i W : , j h ^ j ) ⊤ β W : , i − 1 2 W : , i ⊤ β W : , i ) \hat{h}_i=\sigma\left(b_i+\left(\boldsymbol{v}-\sum\limits_{j\ne i}\boldsymbol{W}_{:,j}\hat{h}_j\right)^\top\beta\boldsymbol{W}_{:,i}-\frac{1}{2}\boldsymbol{W}_{:,i}^\top\beta\boldsymbol{W}_{:,i}\right) h^i=σ bi+(vj=iW:,jh^j)βW:,i21W:,iβW:,i — 公式26 \quad\textbf{---\footnotesize{公式26}} 公式26

  • 在这种新的形式中,我们可以将 v − ∑ j ≠ i W : , j h ^ j \boldsymbol{v}-\sum_{j\ne i}\boldsymbol{W}_{:,j}\hat{h}_j vj=iW:,jh^j 看做是输入,而不是 v \boldsymbol{v} v

    • 因此,我们可以把第 i i i 个单元视作给定其他单元编码时给 v \boldsymbol{v} v 中的剩余误差编码。
    • 由此我们可以将稀疏编码视作是一个迭代的自编码器,将输入反复地编码解码,试图在每一轮迭代后都能修复重构中的误差。
  • 在这个例子中,我们已经推导出了每一次更新单个结点的更新规则。

    • 如果能够同时更新更多的结点,那是非常好的。
    • 某些图模型,比如 DBM \text{DBM} DBM,我们可以同时解出 h ^ \hat{\boldsymbol{h}} h^ 中的许多元素。
    • 不幸的是,二值稀疏编码并不适用这种块更新。
    • 取而代之的是,我们使用一种称为是衰减 ( damping \text{damping} damping) 的启发式技巧来实现块更新。
    • 在衰减方法中,对 h ^ \hat{\boldsymbol{h}} h^ 中的每一个元素我们都可以解出最优值,然后对于所有的值都在这个方向上移动一小步。
    • 这个方法不能保证每一步都能增加 L \mathcal{L} L,但是对于许多模型都很有效。
    • 关于在信息传输算法中如何选择同步程度以及使用衰减策略可以参考 Koller and Friedman (2009) \text{Koller and Friedman (2009)} Koller and Friedman (2009)

变分法

  • 在继续描述变分学习之前,我们有必要简单地介绍一种变分学习中重要的数学工具: 变分法 ( calculus of variations \text{calculus of variations} calculus of variations)。

  • 许多机器学习的技巧是基于寻找一个输入向量 θ ∈ R n \boldsymbol{\theta}\in\mathbb{R}^n θRn 来最小化函数 J ( θ ) J(\boldsymbol{\theta}) J(θ),使得它取到最小值。

    • 这个步骤可以利用多元微积分以及线性代数的知识找到满足 ∇ θ J ( θ ) = 0 \nabla_{\boldsymbol{\theta}} J(\boldsymbol{\theta}) = 0 θJ(θ)=0 的临界点来完成。
    • 在某些情况下,我们希望能够解一个函数 f ( x ) f(\boldsymbol{x}) f(x),比如当我们希望找到一些随机变量的概率密度函数时。
    • 正是变分法能够让我们完成这个目标。
  • f f f 函数的函数被称为是泛函 ( functional \text{functional} functional) J [ f ] J[f] J[f]

    • 正如我们许多情况下对一个函数求关于以向量的元素为变量的偏导数一样,我们可以使用泛函导数 ( functional derivative \text{functional derivative} functional derivative),即在任意特定的 x \boldsymbol{x} x 值,对一个泛函 J [ f ] J[f] J[f] 求关于函数 f ( x ) f(\boldsymbol{x}) f(x) 的导数,这也被称为变分微分 ( variational derivative \text{variational derivative} variational derivative)。
    • 泛函 J J J 的关于函数 f f f 在点 x \boldsymbol{x} x 处的泛函导数被记作: δ δ f ( x ) J \displaystyle\frac{\delta}{\delta f(x)}\boldsymbol{J} δf(x)δJ
  • 完整正式的泛函导数的推导不在本文讨论的范围之内。为了满足我们的目标,讲述可微分函数 f ( x ) f(\boldsymbol{x}) f(x) 以及带有连续导数的可微分函数 g ( y , x ) g(y,\boldsymbol{x}) g(y,x) 就足够了:
    δ δ f ( x ) ∫ g ( f ( x ) , x ) d x = ∂ ∂ y g ( f ( x ) , x ) \displaystyle\frac{\delta}{\delta f(\boldsymbol{x})}\displaystyle\int g(f(\boldsymbol{x}),\boldsymbol{x})dx=\frac{\partial}{\partial y}g(f(\boldsymbol{x}),\boldsymbol{x}) δf(x)δg(f(x),x)dx=yg(f(x),x) — 公式27 \quad\textbf{---\footnotesize{公式27}} 公式27

  • 为了使上述的关系式更加形象,我们可以把 f ( x ) f(\boldsymbol{x}) f(x) 看做是一个有着无穷不可数多元素的向量,由一个实数向量 x \boldsymbol{x} x 表示。在这里(看做是一个不完全的介绍),这种关系式中描述的泛函导数和向量 θ ∈ R n \boldsymbol{\theta}\in\mathbb{R}^n θRn的导数相同:
    ∂ ∂ θ i ∑ j g ( θ j , j ) = ∂ ∂ θ i g ( θ i , i ) \displaystyle\frac{\partial}{\partial \theta_i}\sum\limits_j g(\theta_j,j)=\frac{\partial}{\partial \theta_i} g(\theta_i,i) θijg(θj,j)=θig(θi,i) — 公式28 \quad\textbf{---\footnotesize{公式28}} 公式28

  • 在其他机器学习文献中的许多结果是利用了更为通用的欧拉-拉格朗日方程 ( EulerLagrange Equation \text{EulerLagrange Equation} EulerLagrange Equation),它能够使得 g g g 不仅依赖于 f f f 的导数而且也依赖于 f f f 的值。但是本书中我们不需要完整地讲述这个结果。

  • 为了优化某个函数关于一个向量,我们求出了这个函数关于这个向量的梯度,然后找这个梯度中每一个元素都为 0 0 0 的点。类似的,我们可以通过寻找一个函数使得泛函导数的每个点都等于 0 0 0 从而来优化一个泛函。

  • 下面讲一个这个过程如何工作的例子,我们考虑寻找一个定义在 x ∈ R x \in \mathbb{R} xR 上的有最大微分熵的概率密度函数。我们回过头来看一下一个概率分布 p ( x ) p(x) p(x) 的熵,定义如下:
    H [ p ] = − E x log ⁡ p ( x ) H[p]=-\mathbb{E}_x\log p(x) H[p]=Exlogp(x) — 公式29 \quad\textbf{---\footnotesize{公式29}} 公式29

  • 对于连续的值,这个期望可以看成是一个积分:
    H [ p ] = − ∫ p ( x ) log ⁡ p ( x ) d x H[p]=-\displaystyle\int p(x)\log p(x)dx H[p]=p(x)logp(x)dx — 公式30 \quad\textbf{---\footnotesize{公式30}} 公式30

  • 我们不能简单地仅仅关于函数 p ( x ) p(x) p(x) 最大化 H [ p ] H[p] H[p],因为那样的话结果可能不是一个概率分布。

    • 为了解决这个问题,我们需要使用一个拉格朗日乘子来添加一个 p ( x ) p(x) p(x)积分值为 1 1 1 的约束。
    • 同样的,当方差增大时,熵也会无限制地增加。
    • 因此,寻找哪一个分布有最大熵这个问题是没有意义的。
    • 但是,在给定固定的方差 σ 2 \sigma^2 σ2 时,我们可以寻找一个最大熵的分布。
    • 最后,这个问题还是欠定的,因为在不改变熵的条件下一个分布可以被随意地改变。
    • 为了获得一个唯一的解,我们再加一个约束:分布的均值必须为 μ \mu μ
    • 那么这个问题的拉格朗日泛函可以被写成:
      { L [ p ] = λ 1 ( ∫ p ( x ) d x − 1 ) + λ 2 ( E [ x ] − μ ) + λ 3 ( E [ ( x − μ ) 2 ] − σ 2 ) + H [ p ] — 公式31 = ∫ ( λ 1 p ( x ) + λ 2 p ( x ) x + λ 3 p ( x ) ( x − μ ) 2 − p ( x ) log ⁡ p ( x ) ) d x − λ 1 − μ λ 2 − σ 2 λ 3 — 公式32 \begin{cases} \begin{aligned} \mathcal{L}[p]&=\lambda_1\left(\int p(x)dx-1\right)+\lambda_2(\mathbb{E}[x]-\mu)+\lambda_3(\mathbb{E}[(x-\mu)^2]-\sigma^2)+H[p] &\quad\textbf{---\footnotesize{公式31}}\\\\ &=\displaystyle\int \left(\lambda_1 p(x)+\lambda_2 p(x)x+\lambda_3 p(x)(x-\mu)^2-p(x)\log p(x) \right) dx-\lambda_1-\mu\lambda_2-\sigma^2\lambda_3 &\quad\textbf{---\footnotesize{公式32}} \end{aligned} \end{cases} L[p]=λ1(p(x)dx1)+λ2(E[x]μ)+λ3(E[(xμ)2]σ2)+H[p]=(λ1p(x)+λ2p(x)x+λ3p(x)(xμ)2p(x)logp(x))dxλ1μλ2σ2λ3公式31公式32
  • 为了关于 p p p 最小化拉格朗日乘子,我们令泛函导数等于 0 0 0
    ∀ x , δ δ p ( x ) L = λ 1 + λ 2 x + λ 3 ( x − μ ) 2 − 1 − log ⁡ p ( x ) = 0 \forall x,\displaystyle\frac{\delta}{\delta p(x)}\mathcal{L}=\lambda_1+\lambda_2x+\lambda_3(x-\mu)^2-1-\log p(x)=0 x,δp(x)δL=λ1+λ2x+λ3(xμ)21logp(x)=0 — 公式33 \quad\textbf{---\footnotesize{公式33}} 公式33

  • 这个条件告诉我们 p ( x ) p(x) p(x) 的泛函形式。通过代数运算重组上述方程,我们可以得到:
    p ( x ) = e ( λ 1 + λ 2 x + λ 3 ( x − μ ) 2 − 1 ) p(x)= e^{(\displaystyle\lambda_1+\lambda_2x+\lambda_3(x-\mu)^2-1)} p(x)=e(λ1+λ2x+λ3(xμ)21) — 公式34 \quad\textbf{---\footnotesize{公式34}} 公式34

  • 我们并没有直接假设 p ( x ) p(x) p(x) 取这种形式,而是通过最小化这个泛函从理论上得了这个 p ( x ) p(x) p(x) 的表达式。

    • 为了解决这个最小化问题,我们需要选择 λ \lambda λ 的值来确保所有的约束都能够满足。
    • 我们有很大的选择 λ \lambda λ 的自由。
    • 因为只要约束满足,拉格朗日关于 λ \lambda λ 这个变量的梯度为 0 0 0
    • 为了满足所有的约束,我们可以令 λ 1 = 1 − log ⁡ σ 2 π , λ 2 = 0 , λ 3 = − 1 2 σ 2 \lambda_1=1-\log\sigma\sqrt{2\pi},\lambda_2=0,\lambda_3=-\displaystyle\frac{1}{2\sigma^2} λ1=1logσ2π ,λ2=0,λ3=2σ21,从而得到:
      p ( x ) = N ( x ; μ , σ 2 ) p(x)=\mathcal{N}(x;\mu,\sigma^2) p(x)=N(x;μ,σ2) — 公式35 \quad\textbf{---\footnotesize{公式35}} 公式35
  • 这也是当我们不知道真实的分布时总是使用正态分布的一个原因。因为正态分布拥有最大的熵,我们通过这个假定来保证了最小可能量的结构。

  • 当寻找熵的拉格朗日泛函的临界点并且给定一个固定的方差时,我们只能找到一个对应最大熵的临界点。

    • 那最小化熵的概率密度函数是什么样的呢?
    • 为什么我们无法发现对应着极小点的第二个临界点呢?
    • 原因是没有一个特定的函数能够达到最小的熵值。
    • 当函数把越多的概率密度加到 x = μ + σ x = \mu + \sigma x=μ+σ x = μ − σ x = \mu − \sigma x=μσ 两个点上和越少的概率密度到其他点上时,他们的熵值会减少,而方差却不变。
    • 然而任何把所有的权重都放在这两点的函数的积分并不为 1 1 1,也不是一个有效的概率分布。
    • 所以不存在一个最小熵的概率密度函数,就像不存在一个最小的正实数一样。
    • 然而,我们发现存在一个收敛的概率分布的序列,收敛到权重都在两个点上。
    • 这种情况能够退化为混合 Dirac \text{Dirac} Dirac 分布。
    • 因为 Dirac \text{Dirac} Dirac分布并不是一个单独的概率密度函数,所以 Dirac \text{Dirac} Dirac分布或者混合 Dirac \text{Dirac} Dirac分布并不能对应函数空间的一个点。
    • 所以对我们来说,当寻找一个泛函导数为 0 0 0 的函数空间的点时,这些分布是不可见的。
    • 这就是这种方法的局限之处。
    • Dirac \text{Dirac} Dirac分布这样的分布可以通过其他方法被找到,比如可以被猜到,然后证明它是满足条件的。

连续性潜变量

  • 当我们的图模型包含连续型潜变量时,我们仍然可以通过最大化 L \mathcal{L} L 进行变分推断和学习。然而,我们需要使用变分法来实现关于 q ( h ∣ v ) q(\boldsymbol{h} \mid \boldsymbol{v}) q(hv) 最大化 L \mathcal{L} L

  • 在大多数情况下,研究者并不需要解决任何变分法的问题。取而代之的是, 均
    匀场固定点迭代更新有一种通用的方程。如果我们做了均匀场近似:
    q ( h ∣ v ) = ∏ i q ( h i ∣ v ) q(\boldsymbol{h} \mid \boldsymbol{v})=\prod\limits_i q(h_i\mid\boldsymbol{v}) q(hv)=iq(hiv) — 公式36 \quad\textbf{---\footnotesize{公式36}} 公式36

  • 并且对任何的 j ≠ i j \ne i j=i 固定了 q ( h j ∣ v ) q(h_j \mid \boldsymbol{v}) q(hjv),那么只需要满足 p p p 中任何联合分布中变量的概率值不为 0 0 0,我们就可以通过归一化下面这个未归一的分布:
    q ~ ( h i ∣ v ) = e ( E h − i ∼ q ( h − i ∣ v ) log ⁡ p ~ ( v , h ) ) \tilde{q}(h_i\mid v)=e^{\displaystyle(\mathbb{E}_{\textbf{h}_{-i}\sim q(\textbf{h}_{-i}\mid v)}\log \tilde{p}(\boldsymbol{v},\boldsymbol{h}))} q~(hiv)=e(Ehiq(hiv)logp~(v,h)) — 公式37 \quad\textbf{---\footnotesize{公式37}} 公式37

  • 来得到最优的 q ( h i ∣ v ) q(h_i \mid v) q(hiv)

    • 在这个方程中计算期望就能得到一个正确的 q ( h i ∣ v ) q(h_i \mid v) q(hiv) 的表达式。
    • 我们只有在希望提出一种新形式的变分学习算法时才需要使用变分法来直接推导 q q q 的函数形式。
    • 公式37给出了适用于任何概率模型的均匀场近似。
  • 公式37是一个不动点方程,对每一个 i i i 它都被迭代地反复使用直到收敛。

    • 然而,它还包含着更多的信息。
    • 它还包含了最优解取到的泛函形式,无论我们是否能够通过不动点方程来解出它。
    • 这意味着我们可以利用方程中的泛函形式,把其中一些值当成参数,然后通过任何我们想用的优化算法来解决这个问题。
  • 我们拿一个简单的概率模型作为例子,其中潜变量满足 h ∈ R 2 \boldsymbol{h}\in\mathbb{R}^2 hR2,可见变量只有一个 v v v

    • 假设 p ( h ) = N ( h ; 0 , I ) p(\boldsymbol{h}) = \mathcal{N}(\boldsymbol{h}; 0, \boldsymbol{I}) p(h)=N(h;0,I) 以及 p ( v ∣ h ) = N ( v ; w ⊤ h ; 1 ) p(v \mid \boldsymbol{h}) = N(v;\boldsymbol{w}^\top\boldsymbol{h}; 1) p(vh)=N(v;wh;1),我们可以通过把 h \boldsymbol{h} h 积掉来简化这个模型,结果是关于 v v v 的高斯分布。
    • 这个模型本身并不有趣。
    • 只是为了说明变分法如何应用在概率建模之中,我们才构造了这个模型。
  • 忽略归一化常数时,真实的后验分布可以给出:
    { p ( h ∣ v ) — 公式38 ∝ p ( h , v ) — 公式39 = p ( h 1 ) p ( h 2 ) p ( v ∣ h ) — 公式40 ∝ e ( − 1 2 [ h 1 2 + h 2 2 + ( v − h 1 w 1 − h 2 w 2 ) 2 ] ) — 公式41 = e ( − 1 2 [ h 1 2 + h 2 2 + v 2 + h 1 2 w 1 2 + h 2 2 w 2 2 − 2 v h 1 w 1 − 2 v h 2 w 2 + 2 h i w i h 2 w 2 ) — 公式42 \begin{cases} \begin{aligned} &\quad p(\boldsymbol{h}\mid \boldsymbol{v}) &\quad\textbf{---\footnotesize{公式38}}\\ &\propto p(\boldsymbol{h}, \boldsymbol{v}) &\quad\textbf{---\footnotesize{公式39}}\\ &=p(h_1)p(h_2)p(\boldsymbol{v}\mid\boldsymbol{h}) &\quad\textbf{---\footnotesize{公式40}}\\ &\propto e^{\left(-\displaystyle\frac{1}{2}[h_1^2+h_2^2+(v-h_1w_1-h_2w_2)^2]\right)} &\quad\textbf{---\footnotesize{公式41}}\\ &=e^{\left(-\displaystyle\frac{1}{2}[h_1^2+h_2^2+v^2+h_1^2w_1^2+h_2^2w_2^2-2vh_1w_1-2vh_2w_2+2h_iw_ih_2w_2\right)} &\quad\textbf{---\footnotesize{公式42}} \end{aligned} \end{cases} p(hv)p(h,v)=p(h1)p(h2)p(vh)e(21[h12+h22+(vh1w1h2w2)2])=e(21[h12+h22+v2+h12w12+h22w222vh1w12vh2w2+2hiwih2w2)公式38公式39公式40公式41公式42

  • 在上式中,我们发现由于带有 h 1 h_1 h1 h 2 h_2 h2 乘积项的存在,真实的后验并不能将 h 1 h_1 h1 h 2 h_2 h2 分开。

  • 应用公式37,我们可以得到:
    { q ~ ( h 1 ∣ v ) — 公式43 = e ( E h 2 ∼ q ( h 2 ∣ v ) log ⁡ p ~ ( v , h ) ) — 公式44 = e ( − 1 2 E h 2 ∼ q ( h 2 ∣ v ) [ h 1 2 + h 2 2 + v 2 + h 1 2 w 1 2 + h 2 2 w 2 2 − 2 v h 1 w 1 − 2 v h 2 w 2 + 2 h 1 h 2 w 2 ] ) — 公式45 \begin{cases} \begin{aligned} &\quad\tilde{q}(h_1\mid\boldsymbol{v}) &\quad\textbf{---\footnotesize{公式43}}\\ &=e^{\displaystyle\left(\mathbb{E}_{\text{h}_2\sim q(h_2\mid \boldsymbol{v})}\log \tilde{p}(\boldsymbol{v},\boldsymbol{h})\right)} &\quad\textbf{---\footnotesize{公式44}}\\ &=e^{\left(-\displaystyle\frac{1}{2}\mathbb{E}_{\text{h}_2\sim q(\text{h}_2\mid \boldsymbol{v})}[h_1^2+h_2^2+v^2+h_1^2w_1^2+h_2^2w_2^2-2vh_1w_1-2vh_2w_2+2h_1h_2w_2]\right)} &\quad\textbf{---\footnotesize{公式45}} \end{aligned} \end{cases} q~(h1v)=e(Eh2q(h2v)logp~(v,h))=e(21Eh2q(h2v)[h12+h22+v2+h12w12+h22w222vh1w12vh2w2+2h1h2w2])公式43公式44公式45

  • 从这里,我们可以发现其中我们只需要从 q ( h 2 ∣ v ) q(h_2\mid\boldsymbol{v}) q(h2v)中获得两个有效值: E h 2 ∼ q ( h ∣ v ) [ h 2 ] \mathbb{E}_{\text{h}_2\sim q(\text{h}\mid v)}[h_2] Eh2q(hv)[h2] E h 2 ∼ q ( h ∣ v ) [ h 2 2 ] \mathbb{E}_{\text{h}_2\sim q(\text{h}\mid v)}[h_2^2] Eh2q(hv)[h22]。把这两项记作 〈 h 2 〉 〈h_2〉 h2 〈 h 2 2 〉 〈h_2^2〉 h22,我们可以得到:
    q ~ ( h 1 ∣ v ) = e ( − 1 2 [ h 1 2 + 〈 h 2 2 〉 + v 2 + h 1 2 w 1 2 + 〈 h 2 2 〉 w 2 2 − 2 v h 1 w 1 − 2 v 〈 h 2 〉 w 2 + 2 h 1 w 1 〈 h 2 〉 w 2 ] ) \tilde{q}(h_1\mid \boldsymbol{v})=e^{\left(-\displaystyle\frac{1}{2}[h_1^2+〈h_2^2〉+v^2+h_1^2w_1^2+〈h_2^2〉w_2^2-2vh_1w_1-2v〈h_2〉w_2+2h_1w_1〈h_2〉w_2]\right)} q~(h1v)=e(21[h12+h22+v2+h12w12+h22w222vh1w12vh2w2+2h1w1h2w2]) — 公式46 \quad\textbf{---\footnotesize{公式46}} 公式46

  • 从这里,我们可以发现 q ~ \tilde{q} q~泛函形式满足高斯分布。

    • 因此,我们可以得到 q ( h ∣ v ) = N ( h ; μ , β − 1 ) q(\boldsymbol{h} \mid\boldsymbol{v}) = \mathcal{N}(\boldsymbol{h}; \mu, \beta^{-1}) q(hv)=N(h;μ,β1),其中 μ \mu μ 和对角的 β \beta β 是变分参数,我们可以使用任何方法来优化它。
    • 有必要再强调一下,我们并没有假设 q q q 是一个高斯分布,这个高斯的形式是使用变分法来最大化 q q q 关于 L \mathcal{L} L推导出的。
    • 在不同的模型上应用相同的方法可能会得到不同泛函形式的 q q q 分布。
  • 当然,上述模型只是为了说明情况的一个简单例子。 深度学习中关于变分学习中连续型变量的实际应用可以参考 Goodfellow et al. (2013f) \text{Goodfellow et al. (2013f)} Goodfellow et al. (2013f)

学习和推断之间的相互作用

  • 在学习算法中使用近似推断会影响学习的过程,反过来这也会影响推断算法的准确性。

  • 具体来说,训练算法倾向于以使得近似推断算法中的近似假设变得更加真实的方向来适应模型。当训练参数时,变分学习增加:
    E h ∼ q log ⁡ p ( v , h ) \mathbb{E}_{\textbf{h}\sim q}\log p(\boldsymbol{v},\boldsymbol{h}) Ehqlogp(v,h) — 公式47 \quad\textbf{---\footnotesize{公式47}} 公式47

  • 对于一个特定的 v \boldsymbol{v} v,对于 q ( h ∣ v ) q(\boldsymbol{h} \mid \boldsymbol{v}) q(hv) 中概率很大的 h \boldsymbol{h} h 它增加了 p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv),对于 q ( h ∣ v ) q(\boldsymbol{h} \mid \boldsymbol{v}) q(hv) 中概率很小的 h \boldsymbol{h} h 它减小了 p ( h ∣ v ) p(\boldsymbol{h} \mid \boldsymbol{v}) p(hv)

  • 这种行为使得我们做的近似假设变得合理。如果我们用单峰值近似后验来训练模型,我们将获得一个真实后验的模型,该模型比我们通过使用精确推断训练模型获得的模型更接近单峰值。

  • 因此,估计由于变分近似对模型产生的伤害大小是很困难的。存在几种估计 log ⁡ p ( v ) \log p(\boldsymbol{v}) logp(v) 的方式。

    • 通常我们在训练模型之后估计 log ⁡ p ( v ; θ ) \log p(\boldsymbol{v};\boldsymbol{\theta}) logp(v;θ),然后发现它和 L ( v , θ , q ) \mathcal{L}(\boldsymbol{v},\boldsymbol{\theta}, q) L(v,θ,q) 的差距是很小的。
    • 从这里我们可以得出结论,对于特定的从学习过程中获得的 θ \boldsymbol{\theta} θ 来说,变分近似是很准确的。
    • 然而我们无法直接得到变分近似普遍很准确或者变分近似几乎不会对学习过程产生任何负面影响这样的结论。
    • 为了准确衡量变分近似带来的危害,我们需要知道 θ ∗ = max ⁡ θ log ⁡ p ( v ; θ ) \boldsymbol{\theta}^\ast = \max_{\boldsymbol{\theta}} \log p(\boldsymbol{v};\boldsymbol{\theta}) θ=maxθlogp(v;θ)
    • L ( v , θ , q ) ≈ log ⁡ p ( v ; θ ) \mathcal{L}(\boldsymbol{v},\boldsymbol{\theta}, q) \approx \log p(\boldsymbol{v};\boldsymbol{\theta}) L(v,θ,q)logp(v;θ) log ⁡ p ( v ; θ ) ≪ log ⁡ p ( v ; θ ∗ ) \log p(\boldsymbol{v};\boldsymbol{\theta}) \ll \log p(\boldsymbol{v};\boldsymbol{\theta}^\ast) logp(v;θ)logp(v;θ)同时成立是有可能的。
    • 如果存在 max ⁡ q L ( v , θ ∗ , q ) ≪ log ⁡ p ( v ; θ ∗ ) \max_q \mathcal{L}(\boldsymbol{v},\boldsymbol{\theta}^\ast, q) \ll \log p(\boldsymbol{v};\boldsymbol{\theta}^\ast) maxqL(v,θ,q)logp(v;θ),即在 θ ∗ \boldsymbol{\theta}^\ast θ 点处后验分布太过复杂使得 q q q 分布族无法准确描述,则我们无法学习到一个 θ \boldsymbol{\theta} θ
    • 这样的一类问题是很难发现的,因为只有在我们有一个能够找到 θ ∗ \boldsymbol{\theta}^\ast θ 的超级学习算法时,才能确定地进行上述的比较。

总结

变分推断作为一种高效、可扩展的近似推断方法,在数据科学和机器学习领域展现出了巨大的潜力。它不仅降低了计算成本,提高了推断效率,还与学习算法紧密结合,实现了参数的准确估计和模型的优化。随着大数据时代的到来和机器学习技术的不断发展,变分推断将在更多领域发挥重要作用,推动人工智能技术的不断进步。

总之,变分推断与学习算法的结合,为我们提供了一种处理大规模数据和复杂模型的有效方法。未来,随着相关理论的不断完善和技术的持续发展,变分推断将在人工智能领域发挥更加重要的作用。

往期内容回顾

应用数学与机器学习基础 - 概率与信息论篇

猜你喜欢

转载自blog.csdn.net/benny_zhou2004/article/details/143194905