权重衰减(weight decay)的理解及Tensorflow的实现

权重衰减(weight decay)的理解及Tensorflow的实现

在这里插入图片描述

  1. 概要
  2. 公式解析
  3. 为什么会起作用
  4. Tensorflow的实现

1.概要:

权重衰减即L2正则化,目的是通过在Loss函数后加一个正则化项,通过使权重减小的方式,一定减少模型过拟合的问题。

2.公式解析:

在这里插入图片描述
L2正则化的公式如图;

其中 C0 是原来并没有使用L2正则化时的损失函数,比如交叉熵函数等;

后面的:
在这里插入图片描述
这一项是正则化项,即计算权重矩阵w的所有项的平方和÷2n,然后× λ(也叫正则化系数),作为最终Loss函数的一项参与梯度下降;这样的话我们在训练的反向传播过程中,得到的权重w就会尽可能小,从而一定程度上减小了模型的复杂度,从而一定程度上解决了过拟合问题。

3.为什么L2正则化会起作用:

从直观上讲,L2正则化使得训练的模型在兼顾最小化分类(或其他目标)的Loss的同时,使得权重w尽可能地小,从而将权重约束在一定范围内,减小了模型的复杂度;同时,如果将w约束在一定范围内,也能够有效防止梯度爆炸。

4.Tensorflow实现weight decay:

def add_weight_decay(self, weights, lambda=5e-3):

    weight_decay = tf.multiply(
        tf.nn.l2_loss(weights), lambda, name='weight_loss')
        
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         weight_decay)
    return weights

这里定义了一个add_weight_decay函数,使用了tf.nn.l2_loss函数,其中参数lambda就是我们的λ正则化系数;

使用时需要先传入权重变量:

w = tf.Variable(tf.truncated_normal(shape=[128, 1024]))
w = add_weight_decay(w)

这时我们已经将w的L2正则化Loss放入了REGULARIZATION_LOSSES这个集合里,再在Loss中加入正则化Loss即可:

regular_loss = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)

regular_loss = tf.add_n(regular_loss)

loss = loss + regular_loss
发布了58 篇原创文章 · 获赞 117 · 访问量 6821

猜你喜欢

转载自blog.csdn.net/weixin_44936889/article/details/103705435