深度学习 | 训练网络trick——知识蒸馏

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qiu931110/article/details/88085540

1.原理介绍

知识蒸馏论文地址

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏的概念,通过引入教师网络用以诱导学生网络的训练,实现知识迁移。所以其本质上和迁移学习有点像,但实现方式是不一样的。用“蒸馏”这个词来形容这个过程是相当形象的。用下图来解释这个过程。

教师网络:大规模,参数量大的复杂网络模型。难以应用到设备端的模型。
学生网络:小规模,参数量小的精简网络模型。可应用到设备端的模型,俗称可落地模型。

我们可以认为教师网络是一个混合物,网络复杂的结构就是杂质,是我们不需要用到的东西,而网络学到的概率分布就是精华,是我们需要的。如上图所示,对于教师网络的蒸馏过程,我们可以形象的认为是通过温度系数T,将复杂网络结构中的概率分布蒸馏出来,并用该概率分布来指导精简网络进行训练。整个通过温度系数T的蒸馏过程由如下公式实现:

从上述公式中可以看出,T值越大,概率分布越软(很多网上的博客都这么说)。其实我个人认为这就是在迁移学习的过程中添加了扰动,从而使得精简网络在借鉴学习的时候更有效,泛化能力更强,这其实就是一种抑制过拟合的策略,和其他抑制过拟合策略在原理上是一致的。

2.蒸馏后学习策略

在第一部分中我们介绍了蒸馏的整个过程,那么在蒸馏结束后,精简网络就要开始跟着负责网络的概率分布进行学习了,在这个过程中是使用KL散度来监督这个学习过程的。接下来简单介绍下KL散度的原理。


上述公式为KL散度的定义式,我们最终的学习目标是学生网络能够学习到教师网络的概率分布,也就是两者的概率分布能够尽可能的相同。而根据KL散度的原理为T_Prob和S_Prob越接近,KL散度值越小。基于KL散度的这个原理,我们才可以利用这个指标来作为损失函数的计算策略。

3.代码介绍

知识蒸馏代码github传送门

由于知识蒸馏策略是基于SoftmaxLoss的,因此我们利用caffe实现时,只需要在SoftmaxLoss的基础上,添加一个教室网络,温度系数以及KL散度的计算即可。
(一)在头文件中添加教师网络的定义

(二)在头文件中添加温度系数的定义

(三)在C文件中添加KL散度计算策略

猜你喜欢

转载自blog.csdn.net/qiu931110/article/details/88085540
今日推荐