深度学习的知识蒸馏:Distilling the Knowledge in a Neural Network


1. 概览

《Distilling the Knowledge in a Neural Network》 是一篇关于知识蒸馏(Knowledge Distillation)技术的重要论文,由 Hinton 等人于2015年提出。这篇论文详细介绍了如何将一个大型的、复杂的机器学习模型(教师模型)的知识转移到一个较小的模型(学生模型)中,从而使小模型能够在保留大部分性能的同时拥有更高的效率。

2. 主要思想

知识蒸馏的核心思想是利用教师模型的软标签(soft labels)来训练学生模型。这里的软标签(soft labels)是指教师模型对输入数据预测的概率分布,传统的硬标签(hard labels)是指真实的分类标签。
在这里插入图片描述

2.1 软标签(Soft Labels)

教师模型对输入样本的预测输出是一个概率分布,而不是单一的类别标签。这种概率分布包含了教师模型对于各个类别的置信度,比硬标签提供了更多关于类间关系的信息。

2.2 温度参数(Temperature Parameter)

为了使概率分布更加平滑,引入了一个温度参数 T T T,它可以放大或缩小教师模型输出的 softmax 函数的值。
q i = e z i T ∑ e z i T q_i = \frac{e^{\frac{z_i}T}}{\sum e^{\frac{z_i}T}} qi=eTzieTzi

  • 较低的温度 T T T 更接近标准的 softmax 输出,当 T = 1 T=1 T=1 是就是标准的 softmax;
  • 较高的温度 T T T 会使概率分布更加平滑,更容易被学生模型学习;

在这里插入图片描述

2.3 损失函数(Loss Function)

在知识蒸馏中,蒸馏的学生模型,既希望能学习到教师模型的概率分布情况(soft labels),又能预测偏向真实情况(hard labels),于是 loss 可以分成两项交叉熵之和:
l o s s = α H ( t e a c h e r ( x ) , s t u d e n t ( x ) ) + ( 1 − α ) H ( t a r g e t , s t u d e n t ( x ) ) loss = \alpha H(teacher(x),student(x)) + (1- \alpha) H(target,student(x)) loss=αH(teacher(x),student(x))+(1α)H(target,student(x))

其中:

  • H ( t e a c h e r ( x ) , s t u d e n t ( x ) ) H(teacher(x),student(x)) H(teacher(x),student(x)) 是教师模型与学生模型的交叉熵
  • H ( t a r g e t , s t u d e n t ( x ) ) H(target,student(x)) H(target,student(x)) 是学生模型与真实情况的交叉熵
  • α \alpha α 是一个超参数,用来平衡两个损失项的权重

3. 举例

例如:

  • 教师模型输出的概率结果是: [ 0.1 , 0.4 , 0.5 ] [0.1, 0.4, 0.5] [0.1,0.4,0.5],即是第一项的概率为0.1,第二项的概率为0.4,第三项的概率为0.5
  • 学生模型输出的概率结果是: [ 0.11 , 0.43 , 0.46 ] [0.11, 0.43, 0.46] [0.11,0.43,0.46],即是第一项的概率为0.11,第二项的概率为0.43,第三项的概率为0.46
  • 真实的结果是: [ 0 , 0 , 1 ] [0,0,1] [0,0,1],即实际情况就是对应第三项
  • 假设 α = 0.7 \alpha = 0.7 α=0.7

则:
l o s s = − 0.7 ∗ ( 0.1 ∗ l o g ( 0.11 ) + 0.4 ∗ l o g ( 0.43 ) + 0.5 ∗ l o g ( 0.46 ) ) − 0.3 ∗ l o g ( 0.46 ) loss = -0.7 * (0.1*log(0.11)+0.4*log(0.43)+0.5*log(0.46))-0.3*log(0.46) loss=0.7(0.1log(0.11)+0.4log(0.43)+0.5log(0.46))0.3log(0.46)

4. 应用场景

知识蒸馏技术广泛应用于各种场景,特别是在移动设备和边缘计算中,因为这些设备通常计算资源有限,需要高效的模型。此外,知识蒸馏也被应用于模型压缩、模型加速等领域。

5. 参考

https://arxiv.org/pdf/1503.02531


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_36803941/article/details/143334527