蒸馏法文章选读——Correlation Congruence for Knowledge Distillation

7, Correlation Congruence for Knowledge Distillation

https://arxiv.org/abs/1904.01802

1),创新点:原始的蒸馏法只是用学生网络的某个向量去拟合教师网络的该向量,无论是kl散度还是欧式距离,只是向量之间的映射;但是由于教师网络和学生网络本来的差异性,所以不应该仅仅学习教师网络和学生网络单个样本向量间差异,还应该学习这两个样本间的相关性,instance congruence (学术网络和教师网络预测值的kl散度,KL divergence on predictions of teacher and student) and correlation congruence (教师网络和学生网络的相关性之间的欧式距离 euclidean distance on correlation of teacher and student).

这里的f是网络的embedded feature space的一个点,即一个图像映射到embedded feature space的一个点,

又有一个映射关系如下,其中C是相关矩阵,

Cij表示的是xi和xj在embedding space的相关性

上面的fi函数可以用各种correlation关系来度量,引入下面的correlation congruence函数

扫描二维码关注公众号,回复: 11930133 查看本文章

2),目标损失:所以文章最终的CCDK目标损失为,第一项为真实标签的交叉熵损失,第二项为原始蒸馏法的kl散度损失,第三项为本文新的到的Lcc损失

整个算法的流程如下所示:

3),最后一步就是选择一个合适的F->C的映射

最终选择的是一个非线形RBF核,因为非线形核更加灵活 可以补货向量之间的非线形关系

对右边的公式进行泰勒级数展开

4),minibatch的采样策略

CUR类一致性随机采样:每个minibatch包含6个类别,每个类别取8张图,总共有48张

SUR超类一致性随机采样:使用教师模型对所有数据分k类,然后每个minibatch包含6个超级类别,在每个超级类别中取8张图,总共有48张

5) 实验结果分析

综合上面两个分别在reid和人脸识别的应用来看,用小网络去学习一个大的/有难度的训练集是很困难的,单纯的one-hot向量很难帮助网络学习到好的泛化能力。有两点结论:

1,小网络使用one-hot去直接学习大训练集(ms1m)/难的训练集(MSMT17),很困难。

学生:小网络

教师:大网络

2,小网络主要差在泛化能力上,表4其实在干扰项只有百、千量级的时候差别不明显,到十万量级的时候差别就非常大了,

batch取为40,所以k=1,意味着一个minibatch只有一个人的40张照片,当k=20,一个人只有两张;但是k=1时,由于一个人的图片数可能少于40,所以一个minibatch是很可能出现两个人的图片;综合来看,单个人图片越少越难训练,因为k-20结果是最差的,而且差距很大。

k=4优于k=1说明适当多的个体数是好的。

类内的差异如上图所示,cckd的类内整体颜色要比kd的颜色要深,所以一定程度上让类内更紧密。为什么会这样?再回到公式7,相比传统的蒸馏法,多了一项Lcc,再看公式6,学生网络特征间的距离还需要和教师网络特征间的距离更接近,假设教师网络类内学习的较好,那么cckd相比kd也会有更好的类内的差距。(注意这里严重依赖于教师网络类内学习的效果)

传统kd只需要教师网络的单个向量和学生网络的单个向量接近,cckd还需要学习教师网络的向量间和学生网络的向量间距离更接近,即类内的结构信息,比如教师网络认为某人带眼镜和不带眼镜相比这个人长发和短发差异更大,那学生网络也需要学习到该信息,并且这个差异是通过RBF的p阶泰勒展开式得到的,所以能学习到更多的向量本身每个维度代表的信息。

猜你喜欢

转载自blog.csdn.net/whatwho_518/article/details/89887305
今日推荐