【标准化方法】(4) Weight Normalization 原理解析、代码复现,附Pytorch代码

今天和各位分享一下深度学习中常用的归一化方法,权重归一化(Weight Normalization, WN),通过理论解析,用 Pytorch 复现一下代码。

Weight Normalization 的论文地址如下:https://arxiv.org/pdf/1903.10520.pdf


1. 原理解析

权重归一化(Weight  Normalization,WN)选择对神经网络的权值向量 W 进行参数重写,参数化权重改善条件最优问题来加速收敛,灵感来自批归一化算法,但是并不像批归一化算法一样依赖于批次大小,不会对梯度增加噪声且计算量很小。权重归一化成功用于 LSTM 和对噪声敏感的模型,如强化学习和生成模型。

对深度学习网络权值 W 进行归一化的操作公式如下:

 

w = \frac{g}{||v||} v

通过一个 k 维标量 g 和一个向量 V 对权重向量 W 进行解耦合。标量 g=||W|| ,即权重 W 的大小,||v|| 表示 v 的欧几里得范数(二范数)。

作者提出对参数 v,g 直接重新参数化然后执行新的随机梯度下降,并且认为通过将权重向量(g)的范数与(\frac{v}{||v||})的方向解耦,加速了随机梯度下降的收敛

假设代价函数记为 L,此时的深度学习网络权值的梯度计算公式为:

\Delta_{_g}L=\Delta_{_w}L\cdot\Delta_{_g}W=\frac{\Delta_{_w}L\cdot\nu}{||\nu||}

M_w=I-\frac{ww'}{||w||^2},其中 M_w 是投影矩阵。梯度计算可以写成\Delta_{_v}L=\frac{g}{||v||}\cdot M_{_w}\Delta_{_w}L

\frac{||\Delta v||}{||v||} = c当梯度噪声大时,c 变大,有 \|v'\|=(\|v\|^2+c^2\|v\|^2)^{1/2}>\|v\|,则 \Delta_{v'}L 变小。

当梯度较小时,c 变小趋于0,有 \|v'\|=(\|v\|^2+c^2\|v\|^2)^{1/2} \approx \|v\|。即:权重归一化 WN 使用这种机制做到梯度稳定。另外,作者也发现 ||v|| 对学习率有很强的鲁棒性。

WN 不像 BN 还具有固定神经网络各层产生的特征尺度的好处,WN 需要小心的参数初始化给 v 的范数设定一个范围(正态分布均值为零,标准差为 0.05),这样虽然延长了参数更新的时间,但收敛后的测试性能会比较好。

t = \frac{v \cdot x}{||v||},仅在初始化期间取 g\leftarrow\frac{1}{\sigma[t]},b\leftarrow\frac{-\mu[t]}{\sigma[t]}

可以得到应用 WN 后,

\begin{aligned} & y=\phi(w\cdot x+b) \\ &=\phi(g\cdot{\frac{v}{||v||}}x+b) \\ &=\phi(\frac{1}{\sigma[t]}\cdot\frac{v}{||v||}x-\frac{\mu[t]}{\sigma[t]}) \\ &=\phi(\frac{t-\mu[t]}{\sigma[t]}) \end{aligned}

由上式可得,当 WN 进行参数初始化时可以在一开始达到和 BN 相同的作用。


2. 代码演示

这里以《Micro-Batch Training with Batch-Channel Normalization and Weight Standardization》这篇文章中的权重归一化方法为例,展示一下代码,比较简单,只需要对权重文件的每个通道做归一化处理。示意图如下。

import torch

def WS(weight:torch.Tensor, eps:float):
    # 权重shape=[c_out, c_in, k_h, k_w]
    c_out, c_in, *kernel_shape = weight.shape
    # [c_out, c_in, k_h, k_w]-->[c_out, c_in*k_h*k_w]
    weight = weight.view(c_out, -1)
    # 计算 [c_in*k_h*k_w] 维度上的均值和方差 --> [c_out,1]
    var, mean = torch.var_mean(weight, dim=1, keepdim=True)
    # 权重标准化
    weight = (weight-mean) / torch.sqrt(var+eps)
    # [c_out, c_in*k_h*k_w]-->[c_out, c_in, k_h, k_w]
    return weight.view(c_out, c_in, *kernel_shape)

猜你喜欢

转载自blog.csdn.net/dgvv4/article/details/130585048