如何直观的解释back propagation算法(二)


利用计算图做自动微分时,既有前向模式,也有反向模式。而神经网络中的反向传播就是自动微分的反向模式。事实上,我们还可以用“前向传播”来计算神经网络中的梯度值,但是由于效率原因这个方法并没有被采用。

我们首先考虑下面这个计算图

我们的终极目标是为了计算输出对输入的梯度,即\partial L/\partial x (注意,这里输出L是一维的,即实数,因此是梯度。而x,y,z可以是多维的向量。这个维度的差异非常关键)。y与z可以看做中间状态变量,数学上并不是必须存在的,完全可以去掉它们,直接考虑L(x) 这个函数及其梯度。但是为了自动求梯度,我们把x到L的计算过程拆分成一些基本运算(如加减乘除)的序列执行,这样相邻两层之间的梯度就只有有限几种简单的形式,从而可以结合链式法则自动计算梯度。

由于这样的拆分,我们预先就知道了相邻两层的微分(例如dz/dy, dy/dx )的形式。接下来有两种模式来计算\partial L/\partial x

前向模式:迭代的每一步是从输入层开始往上,计算当前迭代所在的层对输入x的微分。例如假设我们计算到了z,那么在当前这一步我们需要计算的是dz/dx=dz/dy\cdot dy/dx 。其中dy/dx 在迭代的上一步已经计算出来,dz/dy 的形式我们预先知道,所缺的就只有y的值。所以前向过程是先把前向传播,然后将值代入相邻两层的微分计算出微分的值,最后和上一步的微分的值相乘(矩阵相乘),得到当前层的微分的值,从而微分的值也被前向传播。如此这般,到最后我们就能计算得到\partial L/\partial x=(dL/dx)^T

反向模式:迭代的每一步是从输出层开始往下,计算输出L对当前层的微分。例如假设我们计算到了y,那么在当前这一步我们需要计算的是dL/dy=dL/dz\cdot dz/dy 。其中dL/dz 在上一步已经计算得到,dz/dy 的形式我们预先知道,所缺的只有y的值。因此我们在反向传播微分之前,需要先做一次前向的值的计算,把x,y,z,L 都计算出来,这样在反向过程中才能计算出每一步微分的值。当前层的微分值计算出来后,可以继续迭代从而反向传播到输入,得到dL/dx

接下来我们比较这两种模式的优缺点:

1. 空间复杂度:前向模式中,微分的计算只需要当前层的输入层的值,完全不需要以前各层的值,因此只需要维护两个变量:当前层对输入层的微分值,以及当前层的值,占用量相当于O(1)。反向模式则不同,需要先前向计算一遍值并且把所有层的值保存下来,用于后续各层微分值的计算。所以反向模式空间占用量相当于O(n),其中n是计算图输入到输出路径的长度。

2. 时间复杂度:计算上看,前向模式只需要前向传播一次就可以得到输出值以及\partial L/\partial x 的值。尽管每一步迭代实际上需要同时计算值和微分值,但完全可以并行计算在时间上达到O(n)。反向模式需要先前向传播值,然后反向传播微分值,二者不能并行,因此时间是O(2n),量级上一致。但多常数倍时间。然而,这个分析是错误的,因为它忽略了非常重要的一点:在我们的问题里,输出是1维的,而输入以及中间各层是多维的。只看微分值的计算,我们事实上是在做下图的矩阵计算:


前向模式相当于从右边开始一步步计算矩阵乘矩阵,计算复杂度是O(m^3) ;反向模式相当于从左边一步步计算向量乘矩阵,计算复杂度是O(m^2) 。因此时间复杂度上,反向传播在我们的问题中更胜一筹。

如果我们考虑神经网络的结构,那么前向模式在空间占用上的优势甚至进一步丧失殆尽。我们考虑最简单的DNN。尽管通常画DNN结构图时都会画出类似我们一开始那样的结构,但是从计算图来看,DNN实质上是一个的结构:

与一开始的结构不同,在这里我们的目标是要计算输出对所有W_x,W_y,W_z,... 的微分值。在这个情境下,我们再来看两种模式在这种结构下的差异:

1. 前向模式:在每一层,需要维护当前层对之前所有权值的微分值dz/d{W_x},dz/d{W_y}... ,这样才可能在传播到损失函数时得到损失函数对所有这些权值的微分值。由此带来空间复杂度的增加。同时这样还引入了重复计算。

2. 反向模式:在每一层,需要维护输出对当前层的微分值,该微分值相当于被复用于之前每一层里权值的微分计算。因此空间复杂度没有变化。同时也没有重复计算,每一个微分值都在之后的迭代中使用。

因此自然地,在深度学习中反向模式作为主要甚至唯一的自动微分计算手段而使用。

关于自动微分,Numerical Optimization | Jorge Nocedal | Springer 一书的第八章有非常精彩的讲解。事实上前向模式反向模式不仅可以用来计算微分,还可以计算Hessian。


作者:过拟合
链接:https://www.zhihu.com/question/27239198/answer/154510111

猜你喜欢

转载自blog.csdn.net/anneqiqi/article/details/70893075
今日推荐