文章目录
1. 概览
《Distilling the Knowledge in a Neural Network》 是一篇关于知识蒸馏(Knowledge Distillation)技术的重要论文,由 Hinton 等人于2015年提出。这篇论文详细介绍了如何将一个大型的、复杂的机器学习模型(教师模型)的知识转移到一个较小的模型(学生模型)中,从而使小模型能够在保留大部分性能的同时拥有更高的效率。
2. 主要思想
知识蒸馏的核心思想是利用教师模型的软标签(soft labels)来训练学生模型。这里的软标签(soft labels)是指教师模型对输入数据预测的概率分布,传统的硬标签(hard labels)是指真实的分类标签。
2.1 软标签(Soft Labels)
教师模型对输入样本的预测输出是一个概率分布,而不是单一的类别标签。这种概率分布包含了教师模型对于各个类别的置信度,比硬标签提供了更多关于类间关系的信息。
2.2 温度参数(Temperature Parameter)
为了使概率分布更加平滑,引入了一个温度参数 T T T,它可以放大或缩小教师模型输出的 softmax 函数的值。
q i = e z i T ∑ e z i T q_i = \frac{e^{\frac{z_i}T}}{\sum e^{\frac{z_i}T}} qi=∑eTzieTzi
- 较低的温度 T T T 更接近标准的 softmax 输出,当 T = 1 T=1 T=1 是就是标准的 softmax;
- 较高的温度 T T T 会使概率分布更加平滑,更容易被学生模型学习;
2.3 损失函数(Loss Function)
在知识蒸馏中,蒸馏的学生模型,既希望能学习到教师模型的概率分布情况(soft labels),又能预测偏向真实情况(hard labels),于是 loss 可以分成两项交叉熵之和:
l o s s = α H ( t e a c h e r ( x ) , s t u d e n t ( x ) ) + ( 1 − α ) H ( t a r g e t , s t u d e n t ( x ) ) loss = \alpha H(teacher(x),student(x)) + (1- \alpha) H(target,student(x)) loss=αH(teacher(x),student(x))+(1−α)H(target,student(x))
其中:
- H ( t e a c h e r ( x ) , s t u d e n t ( x ) ) H(teacher(x),student(x)) H(teacher(x),student(x)) 是教师模型与学生模型的交叉熵
- H ( t a r g e t , s t u d e n t ( x ) ) H(target,student(x)) H(target,student(x)) 是学生模型与真实情况的交叉熵
- α \alpha α 是一个超参数,用来平衡两个损失项的权重
3. 举例
例如:
- 教师模型输出的概率结果是: [ 0.1 , 0.4 , 0.5 ] [0.1, 0.4, 0.5] [0.1,0.4,0.5],即是第一项的概率为0.1,第二项的概率为0.4,第三项的概率为0.5
- 学生模型输出的概率结果是: [ 0.11 , 0.43 , 0.46 ] [0.11, 0.43, 0.46] [0.11,0.43,0.46],即是第一项的概率为0.11,第二项的概率为0.43,第三项的概率为0.46
- 真实的结果是: [ 0 , 0 , 1 ] [0,0,1] [0,0,1],即实际情况就是对应第三项
- 假设 α = 0.7 \alpha = 0.7 α=0.7
则:
l o s s = − 0.7 ∗ ( 0.1 ∗ l o g ( 0.11 ) + 0.4 ∗ l o g ( 0.43 ) + 0.5 ∗ l o g ( 0.46 ) ) − 0.3 ∗ l o g ( 0.46 ) loss = -0.7 * (0.1*log(0.11)+0.4*log(0.43)+0.5*log(0.46))-0.3*log(0.46) loss=−0.7∗(0.1∗log(0.11)+0.4∗log(0.43)+0.5∗log(0.46))−0.3∗log(0.46)
4. 应用场景
知识蒸馏技术广泛应用于各种场景,特别是在移动设备和边缘计算中,因为这些设备通常计算资源有限,需要高效的模型。此外,知识蒸馏也被应用于模型压缩、模型加速等领域。
5. 参考
https://arxiv.org/pdf/1503.02531
欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;
欢迎关注知乎/CSDN:SmallerFL
也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤