使用二阶梯度作正则项交叉训练参数

使用二阶梯度作正则项交叉训练参数

在上周五讨论时关于交叉训练"语义概念参数"和"视觉概念参数"时我们说到了导致正确率底下的两个缺陷:

  • 训练样本少:这个即将解决,因为我们上次讨论将进行关于"指令"的概念训练,这比需要逻辑推理的任务更加简单,而也可以生成更多的样本;
  • 训练方式的问题。现在我引入一个有效的正则项:这是在凸优化的数学原理上来优化,这是本文要谈的;

值得注意的是,这个方法(用二阶梯度作正则项)不止对我们当前这个任务适用,我更感觉这是一个通用的适用于损失是凸函数的方法;


再谈模型的损失

注意我们上次说到,模型的预测是one-hot形式的空间概念指代:[上下,左右,左上右下… …],那么最后一层适用softmax交叉熵损失来做,那么损失是:

L ( A s ) = y ^ l o g ( f s o f t m a x ( A s C T ) ) \mathcal{L}(A_s) = -\hat{y} \odot log(f_{softmax}(A_s \odot C^{*T}) )

另一边,我们说到, A s A_s 也在另一个网络中发挥权重参数的作用,我们令来自那个网络的损失为 L ^ ( A s ) \hat{\mathcal{L}}(A_s) ;下面我们将证明,这是一个凸函数。


softmax是凸函数

这个证明也可以跳过,只需要记住softmax是凸函数这个结论也行,证明如下,softmax的交叉熵损失是:

L ( w 1 , w 2 ,   , w k ) = 1 m [ i = 1 m j = 1 k 1 { y ( i ) = j } log e w j T x ( i ) l = 1 k e w l T x ( i ) ] \mathcal{L}(w_1, w_2,\cdots, w_k)=-\frac1m\left[\sum_{i=1}^m\sum_{j=1}^k 1\{y^{(i)}=j\}\log\frac{e^{w^T_jx^{(i)}}}{\sum_{l=1}^ke^{w^T_lx^{(i)}}}\right]

现在令 :
a j = e w j T x l = 1 k e w l T x a_j=\frac{e^{w^T_jx}}{\sum_{l=1}^ke^{w^T_lx}}

分情况,当 n j n\neq j
w n a j = e w j T x e w n T x ( l = 1 k e w l T x ) 2 x = a j a n x \nabla_{w_n} a_j=-\frac{e^{w_j^Tx}e^{w_n^Tx}}{\left(\sum_{l=1}^ke^{w_l^Tx}\right)^2}x=-a_ja_nx

n = j n= j
w j a j = ( e w j T x l = 1 k e w l T x e w j T x l = 1 k e w l T x e w j T x l = 1 k e w l T x ) x = a j ( 1 a j ) x \nabla_{w_j} a_j=\left(\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}-\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}\frac{e^{w_j^Tx}}{\sum_{l=1}^ke^{w_l^Tx}}\right)x=a_j(1-a_j)x

所以有:
w n C = 1 m i = 1 m ( j n 1 { y ( i ) = j } a j a n / a j x ( i ) + 1 { y ( i ) = n } a n ( 1 a n ) / a n x ( i ) ) = 1 m i = 1 m [ x ( i ) ( 1 { y ( i ) = n } a n ) ] \nabla_{w_n}C=-\frac1m\sum_{i=1}^m\left(\sum_{j\neq n}-1\{y^{(i)}=j\}a_ja_n/a_jx^{(i)}+1\{y^{(i)}=n\}a_n(1-a_n)/a_nx^{(i)}\right) =-\frac1m\sum_{i=1}^m\left[x^{(i)}(1\{y^{(i)}=n\}-a_n)\right]

注意 1 { y ( i ) = j } 1\{y^{(i)}=j\} 只有当 y ( i ) = j y^{(i)}=j 时为一,那么下式为半正定矩阵,因而对于softmax而言,交叉熵为凸函数:
w n 2 C = 1 m i = 1 m w n [ x ( i ) ( 1 { y ( i ) = n } a n ) ] = 1 m i = 1 m a n ( 1 a n ) x ( i ) x ( i ) T \nabla_{w_n}^2C=-\frac1m\sum_{i=1}^m\nabla_{w_n}\left[x^{(i)}(1\{y^{(i)}=n\}-a_n)\right]=\frac1m\sum_{i=1}^ma_n(1-a_n)x^{(i)}x^{(i)T}


L ^ ( A + α Δ A ) \hat{\mathcal{L}}(A+ \alpha \Delta A) 的一个上界的证明

注意参数 A A 的更新方式采取最简单的SGD:
A t + 1 A t + α A L ^ A^{t+1} \leftarrow A^t+\alpha \nabla_A \hat{\mathcal{L}}

证明:当 A 2 L ^ M I \nabla^2_A \hat{\mathcal{L}} \leq MI 时:
L ^ ( A + α Δ A ) L ^ ( A ) + γ A L ^ 2 \hat{\mathcal{L}}(A+ \alpha \Delta A) \leq \hat{\mathcal{L}}(A) + \gamma||\nabla_A \hat{\mathcal{L}}||^2

proof:

首先易知 A L ^ ( A ) = Δ A -\nabla_A \hat{\mathcal{L}}(A) = \Delta A ;现在我们对 L ^ ( A + α Δ A ) \hat{\mathcal{L}}(A+ \alpha \Delta A) 作Taylor展开:
L ^ ( A + α Δ A ) = L ^ ( A ) + α A L ^ ( A ) Δ A + A 2 L ^ Δ A 2 α 2 / 2 \hat{\mathcal{L}}(A+ \alpha \Delta A) = \hat{\mathcal{L}}(A)+ \alpha \nabla_A \hat{\mathcal{L}}(A) \odot \Delta A +\nabla^2_A \hat{\mathcal{L}} ||\Delta A||^2 \alpha^2 /2
L ^ ( A ) + α A L ^ ( A L ^ ) + M Δ A 2 α 2 / 2 \le \hat{\mathcal{L}}(A)+ \alpha \nabla_A \hat{\mathcal{L}} \odot (-\nabla_A \hat{\mathcal{L}}) +M||\Delta A||^2 \alpha^2 /2
= L ^ ( A ) + ( α 2 M / 2 α ) A L ^ 2 = \hat{\mathcal{L}}(A)+(\alpha^2M /2 - \alpha )||\nabla_A \hat{\mathcal{L}}||^2

现在令 γ = α 2 M / 2 α 0 \gamma = \alpha^2M /2 - \alpha \le 0 即可满足:
L ^ ( A + α Δ A ) L ^ ( A + α Δ A ) + ( α α 2 M / 2 ) A L ^ 2 L ^ ( A ) \hat{\mathcal{L}}(A+ \alpha \Delta A) \le \hat{\mathcal{L}}(A+ \alpha \Delta A) +(\alpha-\alpha^2M /2 )||\nabla_A \hat{\mathcal{L}}||^2 \le \hat{\mathcal{L}}(A)

得证 \Box


最终的正则项

现在我们已知,只需要 A 2 L ^ M I \nabla^2_A \hat{\mathcal{L}} \leq MI ,就可以使得每一步梯度更新保证减小目标函数;我们的正则项是 λ M I 2 A 2 L ^ \lambda \frac{M||I||^2}{\nabla^2_A \hat{\mathcal{L}}} ,最终的损失是:

L ( A s ) = y ^ l o g ( f s o f t m a x ( A s C T ) ) + λ M I 2 A s 2 L ( A s ) ^ \mathcal{L}(A_s) = -\hat{y} \odot log(f_{softmax}(A_s \odot C^{*T}) )+\lambda \frac{M||I||^2}{\nabla^2_{A_s} \hat{\mathcal{L(A_s)}}}

还没有验证实验效果,本周马上会使用这个损失来训练概念学习模型。

发布了142 篇原创文章 · 获赞 71 · 访问量 23万+

猜你喜欢

转载自blog.csdn.net/hanss2/article/details/84544438