From GAN to WGAN

From GAN to WGAN

本文解释了生成对抗网络(GAN)[1]模型背后的数学原理,以及为什么它很难被训练。Wasserstein GAN旨在通过采用平滑度量来衡量两个概率分布之间的距离来改善GAN的训练。

1 Introduction

生成式对抗网络(GAN)[1]在许多生成式任务中显示了巨大的成果,以复制现实世界的丰富内容,如图像、人类语言和音乐。它受到博弈论的启发:两个模型,一个生成器和一个批判者,在相互竞争的同时使对方变得更强大。然而,训练GAN模型是相当具有挑战性的,因为人们正面临着训练不稳定或无法收敛等问题。在这里,我想解释生成式对抗网络框架背后的数学,为什么它很难被训练,最后介绍一个旨在解决训练困难的GAN的修改版本。

2 Kullback–Leibler and Jensen–Shannon Divergence

在我们开始仔细研究GANs之前,让我们首先回顾一下量化两个概率分布之间相似性的两个指标。

(1) KL (Kullback–Leibler) Divergence

衡量一个概率分布p如何偏离第二个预期概率分布q:

在这里插入图片描述

当p(x)==q(x)无处不在时, D K L D_{KL} DKL实现了最小零点。

根据该公式可以看出,KL发散是不对称的。在p(x)接近于零,但q(x)明显不为零的情况下,q的影响被忽略了。当我们只是想测量两个同样重要的分布之间的相似性时,这可能会导致错误的结果。

(2) Jensen–Shannon Divergence 是另一种衡量两个概率分布之间相似性的方法,以[0, 1]为界。JS发散是对称的,而且更加平滑。如果你有兴趣阅读更多关于KL发散和JS发散之间的比较,请查看这篇文章。

在这里插入图片描述

有些人[2]认为,GANs大获成功的原因之一是将损失函数从传统的最大似然法中的非对称KL发散改为对称JS发散。

在这里插入图片描述

图1:给定两个高斯分布,p的平均值=0,std=1,q的平均值=1,std=1.两个分布的平均值被标记为m=(p+q)/2。KL发散 D K L D_{KL} DKL是不对称的,但JS发散 D J S D_{JS} DJS是对称的。

3 Generative Adversarial Network

GAN包含两个模型:

  • 鉴别器D估计一个给定样本来自真实数据集的概率。它像一个批评家critic 一样工作,并被优化以区分虚假样本和真实样本。

  • 一个生成器G在噪声变量输入z的情况下输出合成样本(z带来潜在的输出多样性)。它被训练来捕捉真实的数据分布,以便其生成的样本可以尽可能的真实,或者换句话说,可以欺骗判别器提供一个高概率。

在这里插入图片描述

这两个模型在训练过程中相互竞争:生成器G努力欺骗鉴别器,而批评者模型D则努力不被欺骗。两个模型之间的这种有趣的零和博弈促使两个模型都要提高自己的功能。

在这里插入图片描述

一方面,我们希望通过最大化 E x ∼ p r ( x ) [ l o g D ( x ) ] \mathbb E_{x∼p_r(x)}[log D(x)] Expr(x)[logD(x)]来确保判别器D对真实数据的决定是准确的。同时,给定一个假的样本G(z), z ∼ p z ( z ) z∼p_z(z) zpz(z),判别器有望通过最大化 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ] \mathbb E_{z∼p_z(z)}[log(1-D(G(z))] Ezpz(z)[log(1D(G(z))]输出一个接近零的概率D(G(z))。

另一方面,生成器被训练成增加D产生高概率的假例子的机会,从而使 E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \mathbb E_{z∼p_z(z)}[log(1-D(G(z)))] Ezpz(z)[log(1D(G(z)))]最小。

当把这两个方面结合在一起时,D和G在玩一个minimax的游戏,在这个游戏中我们应该优化以下损失函数。

在这里插入图片描述

3.1 What is the Optimal Value for D?

现在我们有了一个定义明确的损失函数。让我们首先研究一下什么是D的最佳值

在这里插入图片描述

由于我们感兴趣的是什么是D(x)的最佳值,以使L(G, D)最大化,让我们标记为

在这里插入图片描述

然后积分里面的内容(我们可以安全地忽略积分,因为x是在所有可能的值上采样的)是:

在这里插入图片描述

3.2 What is the Global Optimal?

当G和D都处于最佳值时,我们有 p g = p r 和 D ∗ ( x ) = 1 / 2 p_g=p_r和D^∗(x)=1/2 pg=prD(x)=1/2,损失函数变为。

在这里插入图片描述

3.3 What does the Loss Function Represent?

根据第2节所列的公式, p r 和 p g p_r和p_g prpg之间的JS散度可以计算为:

在这里插入图片描述

从本质上讲,当鉴别器最优时,GAN的损失函数通过JS散度来量化生成数据分布 p g p_g pg和真实样本分布 p r p_r pr之间的相似性。复制真实数据分布的最佳G∗导致最小 L ( G ∗ , D ∗ ) = − 2 l o g 2 L(G^∗, D^∗) = -2 log 2 L(G,D)=2log2,这与上面的公式一致。

4 Problems in GANs

尽管GAN在现实的图像生成中显示了巨大的成功,但训练并不容易;众所周知,这个过程是缓慢和不稳定的。

4.1 Hard to Achieve Nash Equilibrium

[3] 讨论了基于GAN梯度下降法的训练过程。同时对两个模型进行训练,得到两人非合作博弈的纳什均衡。但是,每个模型独立地更新其成本,而不考虑游戏中的其他玩家。同时更新两个模型的梯度不能保证收敛。

让我们看看一个简单的例子,以更好地理解为什么在非合作博弈中很难找到纳什均衡。假设一个玩家控制x,以最小化 f 1 ( x ) = x y f_1(x)=xy f1(x)=xy,同时另一个玩家不断更新y,以最小化 f 2 ( y ) = − x y f_2(y)=-xy f2(y)=xy

因为 ∂ f 1 / ∂ x = y , ∂ f 2 / ∂ y = − x , ∂f_1/∂x=y,∂f_2/∂y=-x, f1/x=yf2/y=x我们在一次迭代中同时用 x − η ⋅ y x-η\cdot y xηy更新x,用 y + η ⋅ x y+η\cdot x y+ηx更新y,其中η是学习率。一旦x和y的符号不同,接下来的每一次梯度更新都会引起巨大的震荡,不稳定性会越来越严重,如图3所示。

在这里插入图片描述

4.2 Low Dimensional Supports

在这里插入图片描述

[4]讨论了pr和pg的supports位于低维流形上的问题,以及它是如何彻底导致GAN训练的不稳定性的。

许多现实世界的数据集的维度,如pr所代表的,only appear to be artificially high。它们被发现集中在一个较低维度的流形中。这实际上是流形学习的基本假设。想想现实世界的图像,一旦主题或包含的对象被固定下来,图像就有很多限制要遵循,例如,狗应该有两只耳朵和一条尾巴,摩天大楼应该有一个笔直高大的身体,等等。这些限制使图像远离了拥有高维自由形式的可能性。

pg也位于一个低维流形中。每当生成器被要求对一个更大的图像,如64x64,给定一个小的维度,如100,噪声变量输入z,在这4096个像素上的颜色分布已经被100维的小随机数向量所定义,很难填补整个高维空间。

因为pg和pr都在低维流形中,它们几乎肯定会是不相交的(见图4)。当它们有不相交的supports时,我们总是能够找到一个完美的判别器,100%正确地分离出真实和虚假的样本。[4]

在这里插入图片描述

4.3 Vanishing Gradient

当鉴别器完美时,我们可以保证 D ( x ) = 1 , ∀ x ∈ p r , D ( x ) = 0 , ∀ x ∈ p g D(x)=1,∀x∈p_r,D(x)=0,∀x∈p_g D(x)=1xprD(x)=0xpg。因此损失函数L下降到零,我们最终没有梯度来更新学习迭代中的损失。图5展示了一个实验,当判别器变得更好时,梯度快速消失。

在这里插入图片描述

图5:首先,DCGAN被训练为1、10和25个epoch。然后,在生成器固定的情况下,从头开始训练一个判别器,用原始成本函数测量梯度。我们看到梯度规范迅速衰减(按对数比例),在最好的情况下,经过4000次判别器的迭代,梯度规范衰减了5个数量级。图片来源:[4])。

  • 如果鉴别器表现不好,生成器就没有准确的反馈,损失函数就不能代表现实。
  • 如果鉴别器做得很好,损失函数的梯度就会下降到接近零,学习就会变得超慢甚至 jammed。
  • 这种困境显然能够使GAN的训练变得非常艰难。

4.4 Mode Collapse

在训练过程中,生成器可能会崩溃到一个总是产生相同输出的环境中。这是GANs的一个常见失败案例,通常被称为模式崩溃。即使生成器可能能够欺骗相应的判别器,但它未能学会代表复杂的现实世界数据分布,而是陷入了一个种类极少的小空间。

4.5 Lack of a Proper Evaluation Metric

生成式对抗网络不是天生就有一个好的反对函数,可以告诉我们训练的进展。没有一个好的评估指标,就像在黑暗中工作。没有好的标志来告诉我们何时停止;没有好的指标来比较多个模型的性能。

5 Improved GAN Training

以下建议有助于稳定和改善GANs的训练。
前五种方法是实现GAN训练快速收敛的实用技术[3]。最后两个是在[4]中提出的,用于解决不相交分布的问题。

(1) Feature Matching

特征匹配建议优化判别器,以检查生成器的输出是否与真实样本的预期统计数据相匹配。在这种情况下,新的损失函数被定义为 ∣ ∣ E x ∼ p r f ( x ) − E z ∼ p z ( z ) f ( G ( z ) ) ∣ ∣ 2 2 ||\mathbb E_{x\sim p_r}f(x)-\mathbb E_{z\sim p_z(z)}f(G(z))||_2^2 Exprf(x)Ezpz(z)f(G(z))22,其中f(x)可以是特征统计的任何计算,如平均值或中位数。

(2) Minibatch Discrimination 通过minibatch discrimination,判别器能够在一个批次中消化训练数据点之间的关系,而不是独立处理每个点。

在一个minibatch中,我们对每一对样本之间的接近程度进行近似计算, c ( x i , x j ) c(x_i, x_j ) c(xi,xj),并通过对一个数据点与同批其他样本的接近程度进行加总,得到一个数据点的总体总结, o ( x i ) = Σ j c ( x i , x j ) o(x_i) = \Sigma_jc(x_i, x_j ) o(xi)=Σjc(xi,xj) 。然后, o ( x i ) o(x_i) o(xi)显式地添加到模型的输入。

(3) Historical Averaging

对于这两个模型,在损失函数中加入 ∣ ∣ Θ − 1 t Σ i = 1 t Θ i ∣ ∣ 2 ||Θ-\frac{1}{t}\Sigma^t_{i=1}Θ_i||^2 Θt1Σi=1tΘi2,其中Θ是模型参数,Θi是参数在过去的训练时间i的参数。当Θ在时间上变化过大时,这个添加会惩罚训练速度。

(4) One-sided Label Smoothing

在feeding判别器时,不要提供1和0的标签,而是使用soften values,如0.9和0.1。事实证明,它可以减少网络的脆弱性。

(5) Virtual Batch Normalization (VBN)

每个数据样本都是基于一个固定的批次(“reference batch”)的数据而不是在其minibatch内进行标准化(normalized)。reference batch在开始时选择一次,并在训练过程中保持不变。

(6) Adding Noises

根据第4.2节的讨论,我们现在知道pr和pg在高维空间中是不相交的,这就造成了梯度消失的问题。为了人为地 "分散 "分布,并为两个概率分布的重叠创造更多的机会,一种解决办法是在判别器D的输入上添加连续噪声。

(7) Use Better Metric of Distribution Similarity

vanilla GAN的损失函数衡量pr和pg的分布之间的JS散度。当两个分布不相交时,这个指标无法提供一个有意义的值。

提议用Wasserstein度量代替JS发散,因为它有一个更平滑的值空间。详情见下一节。

6 Wasserstein GAN (WGAN)

6.1 What is Wasserstein Distance?

Wasserstein距离是对两个概率分布之间距离的测量。它也被称为 “搬运工距离”,简称EM距离,因为它可以非正式地解释为将一个概率分布形状的一堆泥土移动并转化为另一个分布形状的最小能量成本。该成本被量化为:移动的泥土量x移动距离。

让我们先看一下概率域是离散的一个简单情况。例如,假设我们有两个分布P和Q,每个分布都有四堆泥土,并且都有十铲子泥土。每个土堆中的铲子的数量分配如下。

在这里插入图片描述

为了把P变成Q的样子,如图7所示,我们。

  • 首先将2个铲子从P1移到P2 => (P1, Q1) 匹配。
  • 然后将2个铲子从P2移到P3 => (P2, Q2)匹配。
  • 最后将1个铲子从Q3移到Q4=>(P3,Q3)和(P4,Q4)相匹配。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fBZu2TwJ-1622282298006)(018.jpg)]

如果我们把使Pi和Qi匹配所付出的成本标记为δi,我们将有 δ i + 1 = δ i + P i − Q i δ_{i+1}=δ_i+P_i-Q_i δi+1=δi+PiQi,而在例子中

在这里插入图片描述

当处理连续概率域时,距离公式变成了:

在这里插入图片描述
在上面的公式中, Π ( p r , p g ) Π(p_r, p_g) Π(pr,pg) p r 和 p g p_r和p_g prpg之间所有可能的联合概率分布的集合。一个联合分布 γ ∈ Π ( p r , p g ) γ∈Π(p_r, p_g) γΠ(pr,pg)描述了一个泥土运输计划,与上面的离散例子相同,但在连续概率空间中。这就是为什么x上的边际分布加起来是 p g , Σ x γ ( x , y ) = p g ( y ) p_g,\Sigma_x γ(x, y) = p_g(y) pgΣxγ(x,y)=pg(y) (Once we finish moving the planned amount of dirt from every possible x to the target y, we end up with exactly what y has according to p g p_g pg.)反之 Σ y γ ( x , y ) = p r ( x ) \Sigma _y γ(x, y) = p_r(x) Σyγ(x,y)=pr(x)

当把x当作起点,y当作目的地时,移动的泥土总量为 γ ( x , y ) γ(x,y) γ(xy),行驶距离为 ∣ ∣ x − y ∣ ∣ ||x - y|| xy,因此成本为 γ ( x , y ) ⋅ ∣ ∣ x − y ∣ ∣ γ(x,y)\cdot||x - y|| γ(xy)xy。所有(x,y)对的平均预期成本可以很容易地计算为:

在这里插入图片描述
最后,我们取所有泥土移动方案中成本最小的一个作为EM距离。在Wasserstein distance的定义中,inf(infimum,也称为greatest lower bound)表示我们只对最小的成本感兴趣。

6.2 Why Wasserstein is better than JS or KL Divergence?

即使两个分布位于没有重叠的低维流形中,Wasserstein距离仍然可以提供一个有意义的、平滑的中间距离的表示。

WGAN的论文用一个简单的例子说明了这个想法。
假设我们有两个概率分布,P和Q:
在这里插入图片描述
在这里插入图片描述
当两个分布不相交时, D K L D_{KL} DKL给了我们无穷大。只有Wasserstein度量提供了一个平滑的度量,这对使用梯度下降的稳定学习过程有很大帮助。

6.3 Use Wasserstein Distance as GAN Loss Function

穷尽 Π ( p r , p g ) Π(p_r, p_g) Π(pr,pg)中所有可能的联合分布来计算 i n f γ ∼ Π ( p r , p g ) inf_{γ∼Π(pr,pg)} infγΠ(pr,pg)是难以实现的。因此,作者提出了一种基于Kantorovich-Rubinstein二元性的公式的智能转换,即(Thus the authors proposed a smart transformation of the formula based on the Kantorovich-Rubinstein duality to)
在这里插入图片描述
其中sup(上确界)与inf(下确界)相反;我们要测量最小上界,或者更简单的说,最大值(where sup (supremum) is the opposite of inf (infimum); we want to measure the least upper bound or, in even simpler words, the maximum value.)。

6.3.1 Lipschitz Continuity

新形式的Wasserstein度量中的函数f被要求满足 ∣ ∣ f ∣ ∣ L ≤ K ||f||_L\le K fLK,意味着它应该是K-Lipschitz连续的。

一个实值函数f : 如果存在一个实数常数K≥0,使得对于所有x1、x2∈R,则称为K-Lipschitz连续。

K-Lipschitz连续即,实值函数f:$\mathbb R→\mathbb R 如 果 存 在 实 常 数 如果存在实常数 K\ge 0KaTeX parse error: Expected 'EOF', got '#' at position 77: …53117324928.JPG#̲pic_center) 这里K…\frac{|f(x_1)-f(x_2)|}{|x_1-x_2|}$,有界限。然而,一个Lipschitz连续函数不一定是到处可微的,如f(x)=|x|。

解释如何在Wasserstein距离公式上发生转换本身就值得写一篇长文章,所以我在这里跳过了细节。如果你对如何使用线性编程计算Wasserstein公式感兴趣,或者对如何根据Kantorovich-Rubinstein二元性将Wasserstein公制转移到其对偶形式感兴趣,请阅读这篇很棒的文章。

6.3.2 Wasserstein Loss Function

假设这个函数f来自一个K-Lipschitz连续函数系列, { f w } w ∈ W \{f_w\}_{w∈\mathcal W} { fw}wW,参数为w。在改进的Wasserstein-GAN中,"判别器 "模型被用来学习w以找到一个好的fw,损失函数被配置为测量pr和pg之间的Wasserstein距离。

在这里插入图片描述
因此,"鉴别器 "不再是一个直接的critic 来区分假样本和真样本了。相反,它被训练为学习一个K-Lipschitz连续函数来帮助计算Wasserstein距离。随着训练中损失函数的减少,Wasserstein距离就会变小,生成器模型的输出就会越来越接近真实数据的分布。

一个大问题是在训练过程中保持 f w f_w fw的K-Lipschitz连续性,以使一切顺利进行。本文提出了一个简单但非常实用的技巧。在每次梯度更新后,将权重w夹在一个小窗口中,如[-0.01, 0.01],从而形成一个紧凑的参数空间W,从而使 f w f_w fw获得其下限和上限以保持Lipschitz的连续性。
在这里插入图片描述
与原始的GAN算法相比,WGAN进行了以下改变。

  • 在每次对批判函数进行梯度更新后,将权重钳制在一个小的固定范围内,[-c, c]。
  • 使用从Wasserstein距离衍生出来的新的损失函数,不再使用对数。"鉴别器 "模型不作为直接的critic,而是作为估计真实和生成的数据分布之间的Wasserstein度量的帮助者。
  • 根据经验,作者建议在批判者上使用RMSProp优化器,而不是基于动量的优化器,如Adam,因为它可能导致模型训练的不稳定性。我还没有看到关于这一点的明确的理论解释。

不幸的是,Wasserstein GAN并不完美。即使是WGAN原始论文的作者也提到,“权重剪裁显然是执行Lipschitz约束的糟糕方式”。WGAN仍然存在训练不稳定、权重剪裁后收敛缓慢(当剪裁窗口太大)和梯度消失(当剪裁窗口太小)的问题。

一些改进,确切地说是用梯度惩罚代替权重剪裁,已在[6]中讨论。

猜你喜欢

转载自blog.csdn.net/weixin_37958272/article/details/117388479
GAN
今日推荐