【知识蒸馏】Masked Generative Distillation


[论文]:Yang Z, Li Z, Shao M, et al. Masked Generative Distillation[J]. arXiv preprint arXiv:2205.01529, 2022.
代码地址
论文地址
论文翻译


一、摘要

知识蒸馏已成功应用于各种任务。当前的蒸馏算法通常通过模仿教师的输出来提高学生的表现。本文表明,教师还可以通过指导学生的特征恢复来提高学生的表征能力。从这个角度来看,我们提出了掩蔽生成蒸馏(MGD),它很简单:我们屏蔽学生特征的随机像素,并迫使它通过一个简单的块生成教师的完整特征。MGD是一种真正通用的基于特征的蒸馏方法,可用于各种任务,包括图像分类、目标检测、语义分割和实例分割。我们在具有广泛数据集的不同模型上进行了实验,结果表明所有学生都取得了出色的改进。值得注意的是,我们将 ResNet-18 从 69.90% 提高到 71.69% ImageNet top-1 准确率,ResNet-50 主干的 RetinaNet 从 37.4 提高到 41.0 边界框 mAP,SOLO 基于 ResNet-50 从 33.1 提高到 36.2 Mask mAP,DeepLabV3 基于 ResNet-18 从 73.20 提高到 76.02 mIoU。

二、主要贡献

1.引入了一种新的基于特征的知识蒸馏方法,它使学生通过其掩码特征生成教师的特征,而不是直接模仿它。
2.提出了一种新的基于特征的蒸馏方法——掩蔽生成蒸馏,它简单且易于使用两个超参数。
3.我们通过对不同数据集的大量实验来验证我们的方法在各种模型上的有效性。对于图像分类和密集预测任务,学生使用 MGD 取得了显着的改进。

三、创新点灵感分析

之前的feature-based蒸馏方法通常会让学生模型尽可能模仿教师模型的输出,因为教师模型通常有着更强的表示能力。在本文中,作者发现直接去模仿教师模型来提升学生特征的表示能力其实是不必要的,如果让学生模型使用部分pixels来重建教师模型的全部特征,那么学生模型对这些使用到的pixels的表示能力也会得到提升。
在这里插入图片描述
上图为FPN输出的第一层要素的可视化。教师:RetinaNet-ResNeXt101。学生:RetinaNet-ResNet50。FGD是一种检测器的提取方法,它迫使学生模仿老师的特征。
由上图可以看出学生模型和教师模型的特征存在差异,同时教师模型的mAP也比学生模型高。采用SOTA的蒸馏方法进行蒸馏后(使用注意力来强迫学生模型模拟教师模型的特征),学生模型的特征与教师模型更相似,同时mAP也得到极大的提升。而使用本文的蒸馏方法训练后,学生模型与教师模型特征虽然相差较大,但是mAP甚至达到教师模型的水平。

四、总体框架

4.1 算法介绍

以前的基于特征的提取方法通常让学生尽可能地模仿老师的输出,因为老师的特征具有更强的表示能力。但是,我们认为没有必要直接模仿老师来提高学生特征的表征能力。用于提取的特征一般是通过深度网络的高阶语义信息。特征像素在一定程度上已经包含了相邻像素的信息。所以,如果能通过简单的分块,用部分像素还原老师的全部特征,这些用过的像素的表现力也能得到提升。
从这个角度出发,我们提出了一种简单有效的基于特征的提取方法——掩蔽生成提取法。如下图所示,我们首先屏蔽学生特征的随机像素,然后通过一个简单的块用屏蔽的特征生成教师的完整特征。由于在每次迭代中使用随机像素,因此在整个训练过程中将使用所有像素,这意味着该特征将更加鲁棒,并且其表示能力将得到提高。在我们的方法中,老师只是作为学生恢复特征的指导,并不要求学生直接模仿。
在这里插入图片描述

4.2 Generation with Masked Feature

对于基于 CNN 的模型,更深层的特征具有更大的感受野和更好地表示原始输入图像。换句话说,特征图像素已经在一定程度上包含了相邻像素的信息。因此,我们可以使用部分像素来恢复完整的特征图。
我们的方法旨在通过学生的掩码特征生成教师的特征,这有助于学生获得更好的表示。
我们用 T l ∈ R C × H × W T^l ∈ R^{C×H×W} TlRC×H×W S l ∈ R C × H × W ( l = 1 , . . , L ) S^l ∈ R^{C×H×W} (l = 1,.., L) SlRC×H×W(l=1,..,L)教师和学生的第 l l l个特征图。首先,我们设置第 l l l 个随机掩码来覆盖学生的第 l l l 个特征,可以表示为:
M i , j l = { 0 , if  R i , j l < λ   1  Otherwise (1) M^l_{i,j}= \begin{cases} 0, & \text {if $R^l_{i,j}<\lambda $ } \\ 1 & \text{ Otherwise} \end{cases} \tag {1} Mi,jl={ 0,1if Ri,jl<λ  Otherwise(1)其中 R i , j l R^l_{i,j} Ri,jl是 (0, 1) 中的随机数,i, j 分别是特征图的水平坐标和垂直坐标。λ 是一个超参数,表示掩码比率。第 l l l 个特征图由第 l l l个随机掩码覆盖。
对应的代码如下,self.lambda_mgd代表masked ratio. Defaults to 0.65,mat代表生成的随机掩码覆盖:

device = preds_S.device
mat = torch.rand((N,1,H,W)).to(device)
mat = torch.where(mat>1-self.lambda_mgd, 0, 1).to(device)

然后我们使用相应的掩码来覆盖学生的特征图,并尝试生成具有左像素的教师特征图,可以表述如下: G ( f a l i g n ( S l ) ⋅ M l ) ⟶ T l (2) G(f_{align}(S^l)\cdot M^l)\longrightarrow T^l\tag {2} G(falign(Sl)Ml)Tl(2) G ( F ) = W l 2 ( R e L U ( W l 1 ( F ) ) ) (3) G(F)=W_{l2}(ReLU(W_{l1}(F))) \tag {3} G(F)=Wl2(ReLU(Wl1(F)))(3) G G G表示包含两个卷积层的投影仪层: W l 1 W_{l1} Wl1 W l 2 W_{l2} Wl2,一个激活层 ReLU。在本文中,我们采用了 1×1 的卷积层对于适配层 f a l i g n f_{align} falign,投影仪层 W l 1 W_{l1} Wl1 W l 2 W_{l2} Wl2的3×3卷积层。用于将覆盖后的学生网络生成生成教师的feature_maps

公式2的代码为将学生网络特征与生成的随机掩码覆盖相乘,最终能得到覆盖后的特征:

masked_fea = torch.mul(preds_S, mat)

之后由公式3将新生成的masked_fea 进一步处理,尝试生成教师的feature_maps,对应的代码如下:

self.generation = nn.Sequential(
     nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
     nn.ReLU(inplace=True), 
     nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))
new_fea = self.generation(masked_fea)

根据这种方法,我们设计了MGD的蒸馏损失 L d i s L_{dis} Ldis:
L d i s ( S , T ) = ∑ l = 1 L ∑ k = 1 C ∑ i = 1 H ∑ j = 1 W ( T k , i , j l − G ( f a l i g n ( S k , i , j l ) M i , j l ) ) 2 (4) L_{dis}(S,T)=\sum\limits_{l=1}^L\sum\limits_{k=1}^C\sum\limits_{i=1}^H\sum\limits_{j=1}^W(T^l_{k,i,j}-G(f_{align}(S^l_{k,i,j})M^l_{i,j}))^2\tag {4} Ldis(S,T)=l=1Lk=1Ci=1Hj=1W(Tk,i,jlG(falign(Sk,i,jl)Mi,jl))2(4)其中 L 是蒸馏层的总和,C、H、W 表示特征图的形状。S 和 T 分别表示学生和教师的特征。对应的代码如下:

dis_loss = loss_mse(new_fea, preds_T)/N

这里值得注意的是,本文仅需要两个超参数。分别为:掩码率 λ \lambda λ、loss平衡参数 α \alpha α,相比于其他的蒸馏算法调参更为简单。

五、总结

以前基于特征的提炼方法通常会让学生尽可能地模仿老师的输出,因为老师的特征具有更强的代表性。然而,作者认为没有必要直接模仿老师来提高学生特征的表示力。用于提炼的特征一般是通过深度网络的高阶语义信息。特征像素在一定程度上已经包含了相邻像素的信息。因此,如果可以通过简单的区块来使用部分像素来还原老师的完整特征,那么这些被使用的像素的表示力也可以得到提高。
通过这个掩膜,获得了部分的特征图,然后再生成新的特征图去模仿教师网络的特征图,相比原始的特征模仿,多的这一步,是增大网络学习的难度,从而迫使学生网络去学习一个更优秀的特征表示,而生成的特征图去模仿教师网络是因为教师网络的特征表示更优秀,通过模仿可以让学生网络训练时候的”进步“方向不走偏,往学习更优秀的特征表示的方向走。

猜你喜欢

转载自blog.csdn.net/AaaA00000001/article/details/128521196
今日推荐