权重衰减(weight decay)的理解及Tensorflow的实现
- 概要
- 公式解析
- 为什么会起作用
- Tensorflow的实现
1.概要:
权重衰减即L2正则化,目的是通过在Loss函数后加一个正则化项,通过使权重减小的方式,一定减少模型过拟合的问题。
2.公式解析:
L2正则化的公式如图;
其中 C0 是原来并没有使用L2正则化时的损失函数,比如交叉熵函数等;
后面的:
这一项是正则化项,即计算权重矩阵w的所有项的平方和÷2n,然后× λ(也叫正则化系数),作为最终Loss函数的一项参与梯度下降;这样的话我们在训练的反向传播过程中,得到的权重w就会尽可能小,从而一定程度上减小了模型的复杂度,从而一定程度上解决了过拟合问题。
3.为什么L2正则化会起作用:
从直观上讲,L2正则化使得训练的模型在兼顾最小化分类(或其他目标)的Loss的同时,使得权重w尽可能地小,从而将权重约束在一定范围内,减小了模型的复杂度;同时,如果将w约束在一定范围内,也能够有效防止梯度爆炸。
4.Tensorflow实现weight decay:
def add_weight_decay(self, weights, lambda=5e-3):
weight_decay = tf.multiply(
tf.nn.l2_loss(weights), lambda, name='weight_loss')
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
weight_decay)
return weights
这里定义了一个add_weight_decay函数,使用了tf.nn.l2_loss函数,其中参数lambda就是我们的λ正则化系数;
使用时需要先传入权重变量:
w = tf.Variable(tf.truncated_normal(shape=[128, 1024]))
w = add_weight_decay(w)
这时我们已经将w的L2正则化Loss放入了REGULARIZATION_LOSSES这个集合里,再在Loss中加入正则化Loss即可:
regular_loss = tf.get_collection(
tf.GraphKeys.REGULARIZATION_LOSSES)
regular_loss = tf.add_n(regular_loss)
loss = loss + regular_loss