There are two main types of knowledge distillation: logits distillation and feature distillation. Logits distillation applies a higher temperature coefficient in the softmax to amplify the information carried by the negative labels, and then uses the KL divergence between the Student's and the Teacher's high-temperature softmax outputs as the loss. Intermediate feature distillation forces the Student to learn the features of some of the Teacher's intermediate layers, either by matching the intermediate features directly or by learning the transformation between features. For example, the knowledge between feature No.1 and feature No.2 can be expressed as how to model the transformation between the two: this transformation can be represented as a matrix, and the Student is trained to reproduce the matrix generated by the Teacher.
This article collects commonly used knowledge distillation papers together with their code, to facilitate further study and research.
1、Logits
Paper link: https://proceedings.neurips.cc/paper/2014/file/ea8fcd92d59581717e06eb187f10666d-Paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class Logits(nn.Module):
    '''
    Do Deep Nets Really Need to be Deep?
    http://papers.nips.cc/paper/5484-do-deep-nets-really-need-to-be-deep.pdf

    Logit-matching loss: mean squared error between the student outputs
    and the teacher outputs.
    '''

    def __init__(self):
        super(Logits, self).__init__()

    def forward(self, out_s, out_t):
        '''
        Args:
            out_s: student outputs.
            out_t: teacher outputs, expected to match out_s in shape.

        Returns:
            Scalar tensor holding the element-wise mean squared error.
        '''
        return F.mse_loss(out_s, out_t)
2、ST
Paper link: https://arxiv.org/pdf/1503.02531.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftTarget(nn.Module):
    '''
    Distilling the Knowledge in a Neural Network
    https://arxiv.org/pdf/1503.02531.pdf

    KL divergence between the temperature-softened teacher distribution
    and the temperature-softened student distribution.

    Args:
        T: softmax temperature; both outputs are divided by it.
    '''

    def __init__(self, T):
        super(SoftTarget, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        '''
        Args:
            out_s: student outputs; softmax is taken over dim 1.
            out_t: teacher outputs; softmax is taken over dim 1.

        Returns:
            Scalar tensor: batch-mean KL divergence multiplied by T * T.
        '''
        log_prob_s = F.log_softmax(out_s / self.T, dim=1)
        prob_t = F.softmax(out_t / self.T, dim=1)
        kl = F.kl_div(log_prob_s, prob_t, reduction='batchmean')
        return kl * self.T * self.T
3、AT
Paper link: https://arxiv.org/pdf/1612.03928.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
AT with sum of absolute values with power p
'''
class AT(nn.Module):
    '''
    Paying More Attention to Attention: Improving the Performance of
    Convolutional Neural Networks via Attention Transfer
    https://arxiv.org/pdf/1612.03928.pdf

    Args:
        p: exponent applied to the absolute activations before they are
           summed over the channel dimension.
    '''

    def __init__(self, p):
        super(AT, self).__init__()
        self.p = p

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: student feature map (4-D; dims 2 and 3 are reduced by the
                  norm in attention_map).
            fm_t: teacher feature map whose attention map must match the
                  student's in shape.

        Returns:
            Scalar tensor: MSE between the two normalized attention maps.
        '''
        return F.mse_loss(self.attention_map(fm_s), self.attention_map(fm_t))

    def attention_map(self, fm, eps=1e-6):
        '''
        Collapse the channel dimension into one attention map per sample.

        Computes sum_c |fm|^p over dim 1 (kept as a singleton dim) and
        divides each map by its L2 norm over dims (2, 3); eps guards the
        division when a map is all zeros.
        '''
        am = torch.pow(torch.abs(fm), self.p)
        am = torch.sum(am, dim=1, keepdim=True)
        norm = torch.norm(am, dim=(2, 3), keepdim=True)
        return torch.div(am, norm + eps)
4、FitNets
Paper link: https://arxiv.org/pdf/1412.6550.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class Hint(nn.Module):
    '''
    FitNets: Hints for Thin Deep Nets
    https://arxiv.org/pdf/1412.6550.pdf

    Hint loss: mean squared error between a student feature map and a
    teacher feature map.
    '''

    def __init__(self):
        super(Hint, self).__init__()

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: student feature map.
            fm_t: teacher feature map, expected to match fm_s in shape
                  (no adaptation layer is applied here).

        Returns:
            Scalar tensor holding the element-wise mean squared error.
        '''
        return F.mse_loss(fm_s, fm_t)
5、NST
Paper link: https://arxiv.org/pdf/1707.01219.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
NST with Polynomial Kernel, where d=2 and c=0
'''
class NST(nn.Module):
    '''
    Like What You Like: Knowledge Distill via Neuron Selectivity Transfer
    https://arxiv.org/pdf/1707.01219.pdf

    Discrepancy between student and teacher channel activations, measured
    with the squared dot-product kernel implemented in poly_kernel.
    '''

    def __init__(self):
        super(NST, self).__init__()

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: student feature map; everything after dim 1 is flattened.
            fm_t: teacher feature map; must flatten to the same length as
                  fm_s along the last dimension.

        Returns:
            Scalar tensor: mean k(t, t) + mean k(s, s) - 2 * mean k(s, t).
        '''
        # One L2-normalized row per (sample, channel).
        fm_s = fm_s.view(fm_s.size(0), fm_s.size(1), -1)
        fm_s = F.normalize(fm_s, dim=2)

        fm_t = fm_t.view(fm_t.size(0), fm_t.size(1), -1)
        fm_t = F.normalize(fm_t, dim=2)

        return (self.poly_kernel(fm_t, fm_t).mean()
                + self.poly_kernel(fm_s, fm_s).mean()
                - 2 * self.poly_kernel(fm_s, fm_t).mean())

    def poly_kernel(self, fm1, fm2):
        '''
        Squared dot product between every channel of fm1 and every channel
        of fm2, computed per sample by broadcasting over inserted dims.
        '''
        fm1 = fm1.unsqueeze(1)
        fm2 = fm2.unsqueeze(2)
        return (fm1 * fm2).sum(-1).pow(2)
6、PKT
Paper link: http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
Adopted from https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py
'''
class PKTCosSim(nn.Module):
    '''
    Learning Deep Representations with Probabilistic Knowledge Transfer
    http://openaccess.thecvf.com/content_ECCV_2018/papers/Nikolaos_Passalis_Learning_Deep_Representations_ECCV_2018_paper.pdf

    Matches the row-wise similarity distributions that the student and the
    teacher induce over the samples of a batch.
    '''

    def __init__(self):
        super(PKTCosSim, self).__init__()

    def forward(self, feat_s, feat_t, eps=1e-6):
        '''
        Args:
            feat_s: 2-D student features, one row per sample.
            feat_t: 2-D teacher features with the same number of rows.
            eps: stabilizer for the norm division and for the logarithm.

        Returns:
            Scalar tensor: mean of p_t * log((p_t + eps) / (p_s + eps)).
        '''
        prob_s = self._cond_prob(feat_s, eps)
        prob_t = self._cond_prob(feat_t, eps)
        return torch.mean(prob_t * torch.log((prob_t + eps) / (prob_s + eps)))

    @staticmethod
    def _cond_prob(feat, eps):
        '''Turn features into row-normalized, [0, 1]-scaled cosine similarities.'''
        # Normalize each vector by its norm; any NaN entry is replaced by 0.
        norm = torch.sqrt(torch.sum(feat ** 2, dim=1, keepdim=True))
        feat = feat / (norm + eps)
        feat = torch.where(torch.isnan(feat), torch.zeros_like(feat), feat)

        # Pairwise cosine similarity, shifted from [-1, 1] to [0, 1].
        cos_sim = torch.mm(feat, feat.transpose(0, 1))
        cos_sim = (cos_sim + 1.0) / 2.0

        # Each row is divided by its sum so that it sums to one.
        return cos_sim / torch.sum(cos_sim, dim=1, keepdim=True)
7、FSP
Paper link: http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class FSP(nn.Module):
    '''
    A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
    http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

    MSE between the student's and the teacher's FSP matrices, each built
    from a pair of feature maps of the same network.
    '''

    def __init__(self):
        super(FSP, self).__init__()

    def forward(self, fm_s1, fm_s2, fm_t1, fm_t2):
        '''
        Args:
            fm_s1, fm_s2: pair of 4-D student feature maps.
            fm_t1, fm_t2: pair of 4-D teacher feature maps; their FSP
                          matrix must match the student's in shape.

        Returns:
            Scalar tensor: MSE between the two FSP matrices.
        '''
        return F.mse_loss(self.fsp_matrix(fm_s1, fm_s2),
                          self.fsp_matrix(fm_t1, fm_t2))

    def fsp_matrix(self, fm1, fm2):
        '''
        Channel-by-channel inner products of fm1 and fm2, averaged over the
        spatial positions. Result has shape
        (fm1.size(0), fm1.size(1), fm2.size(1)).

        When the heights differ, the taller map is average-pooled to the
        other map's spatial size so the flattened maps line up for bmm.
        '''
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
        elif fm1.size(2) < fm2.size(2):
            # Mirror case: without this the bmm below fails on mismatched
            # inner dimensions.
            fm2 = F.adaptive_avg_pool2d(fm2, (fm1.size(2), fm1.size(3)))

        fm1 = fm1.view(fm1.size(0), fm1.size(1), -1)
        fm2 = fm2.view(fm2.size(0), fm2.size(1), -1).transpose(1, 2)

        return torch.bmm(fm1, fm2) / fm1.size(2)
8、FT
Paper link: http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class FT(nn.Module):
    '''
    Paraphrasing Complex Network: Network Compression via Factor Transfer
    http://papers.nips.cc/paper/7541-paraphrasing-complex-network-network-compression-via-factor-transfer.pdf

    L1 distance between the L2-normalized student and teacher factors.
    '''

    def __init__(self):
        super(FT, self).__init__()

    def forward(self, factor_s, factor_t):
        '''
        Args:
            factor_s: student factor; flattened per sample.
            factor_t: teacher factor; must flatten to the same length.

        Returns:
            Scalar tensor: mean absolute difference of the normalized factors.
        '''
        return F.l1_loss(self.normalize(factor_s), self.normalize(factor_t))

    def normalize(self, factor):
        '''Flatten each sample to a vector and scale it to unit L2 norm.'''
        flat = factor.view(factor.size(0), -1)
        return F.normalize(flat)
9、RKD
Paper link: https://arxiv.org/pdf/1904.05068.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
From https://github.com/lenscloth/RKD/blob/master/metric/loss.py
'''
class RKD(nn.Module):
    '''
    Relational Knowledge Distillation
    https://arxiv.org/pdf/1904.05068.pdf

    Weighted sum of a distance-wise term and an angle-wise term computed
    between the samples of a batch.

    Args:
        w_dist: weight of the distance-wise term.
        w_angle: weight of the angle-wise term.
    '''

    def __init__(self, w_dist, w_angle):
        super(RKD, self).__init__()
        self.w_dist = w_dist
        self.w_angle = w_angle

    def forward(self, feat_s, feat_t):
        '''
        Args:
            feat_s: 2-D student features, one row per sample.
            feat_t: 2-D teacher features with the same number of rows.

        Returns:
            Scalar tensor: w_dist * rkd_dist + w_angle * rkd_angle.
        '''
        return (self.w_dist * self.rkd_dist(feat_s, feat_t)
                + self.w_angle * self.rkd_angle(feat_s, feat_t))

    def rkd_dist(self, feat_s, feat_t):
        '''
        Smooth-L1 loss between the pairwise distance matrices, each divided
        by the mean of its strictly positive entries.
        '''
        feat_t_dist = self.pdist(feat_t, squared=False)
        feat_t_dist = feat_t_dist / feat_t_dist[feat_t_dist > 0].mean()

        feat_s_dist = self.pdist(feat_s, squared=False)
        feat_s_dist = feat_s_dist / feat_s_dist[feat_s_dist > 0].mean()

        return F.smooth_l1_loss(feat_s_dist, feat_t_dist)

    def rkd_angle(self, feat_s, feat_t):
        '''
        Smooth-L1 loss between the cosines formed by every triplet of
        samples (dot products of unit-normalized difference vectors).
        '''
        # N x C --> N x N x C
        feat_t_vd = feat_t.unsqueeze(0) - feat_t.unsqueeze(1)
        norm_feat_t_vd = F.normalize(feat_t_vd, p=2, dim=2)
        feat_t_angle = torch.bmm(norm_feat_t_vd,
                                 norm_feat_t_vd.transpose(1, 2)).view(-1)

        feat_s_vd = feat_s.unsqueeze(0) - feat_s.unsqueeze(1)
        norm_feat_s_vd = F.normalize(feat_s_vd, p=2, dim=2)
        feat_s_angle = torch.bmm(norm_feat_s_vd,
                                 norm_feat_s_vd.transpose(1, 2)).view(-1)

        return F.smooth_l1_loss(feat_s_angle, feat_t_angle)

    def pdist(self, feat, squared=False, eps=1e-12):
        '''
        Pairwise Euclidean distances between the rows of feat.

        Args:
            feat: 2-D tensor, one row per sample.
            squared: return squared distances when True.
            eps: lower clamp applied to the squared distances before the
                 square root is taken.

        Returns:
            N x N tensor whose diagonal is set to exactly zero.
        '''
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod = torch.mm(feat, feat.t())
        feat_dist = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1)
                     - 2 * feat_prod).clamp(min=eps)

        if not squared:
            feat_dist = feat_dist.sqrt()

        # Work on a copy before overwriting the diagonal in place.
        feat_dist = feat_dist.clone()
        diag = range(len(feat))
        feat_dist[diag, diag] = 0

        return feat_dist
10、AB
Paper link: https://arxiv.org/pdf/1811.03233.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class AB(nn.Module):
    '''
    Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons
    https://arxiv.org/pdf/1811.03233.pdf

    Args:
        margin: distance from zero that the student response is pushed
                towards on the side indicated by the teacher's sign.
    '''

    def __init__(self, margin):
        super(AB, self).__init__()
        self.margin = margin

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: student feature map before activation.
            fm_t: teacher feature map before activation, same shape.

        Returns:
            Scalar tensor: mean of the squared margin violations.
        '''
        # Teacher is non-positive but the student is still above -margin.
        too_high = ((fm_s > -self.margin) & (fm_t <= 0)).float()
        # Teacher is positive but the student has not passed +margin.
        too_low = ((fm_s <= self.margin) & (fm_t > 0)).float()

        loss = ((fm_s + self.margin).pow(2) * too_high
                + (fm_s - self.margin).pow(2) * too_low)
        return loss.mean()
11、SP
Paper link: https://arxiv.org/pdf/1907.09682.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class SP(nn.Module):
    '''
    Similarity-Preserving Knowledge Distillation
    https://arxiv.org/pdf/1907.09682.pdf

    MSE between the row-normalized batch similarity (Gram) matrices of the
    student and the teacher.
    '''

    def __init__(self):
        super(SP, self).__init__()

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: student features; flattened per sample.
            fm_t: teacher features with the same batch size.

        Returns:
            Scalar tensor: MSE between the two normalized Gram matrices.
        '''
        return F.mse_loss(self._normalized_gram(fm_s),
                          self._normalized_gram(fm_t))

    @staticmethod
    def _normalized_gram(fm):
        '''Batch-by-batch Gram matrix with every row scaled to unit L2 norm.'''
        flat = fm.view(fm.size(0), -1)
        gram = torch.mm(flat, flat.t())
        return F.normalize(gram, p=2, dim=1)
12、Sobolev
Paper link: https://arxiv.org/pdf/1706.04859.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
class Sobolev(nn.Module):
    '''
    Sobolev Training for Neural Networks
    https://arxiv.org/pdf/1706.04859.pdf

    Knowledge Transfer with Jacobian Matching
    http://de.arxiv.org/pdf/1803.00443

    MSE between the L2-normalized input gradients of the student's and the
    teacher's target-class outputs.
    '''

    def __init__(self):
        super(Sobolev, self).__init__()

    def forward(self, out_s, out_t, img, target):
        '''
        Args:
            out_s: student outputs, computed from img.
            out_t: teacher outputs, computed from the same img.
            img: input tensor that both graphs depend on; it must require
                 grad for the gradients below to exist.
            target: index tensor selecting one column of the outputs per
                    sample.

        Returns:
            Scalar tensor. The teacher gradient is detached, so only the
            student side receives gradient from this loss.
        '''
        norm_grad_s = self._input_grad(out_s, img, target, create_graph=True)
        # The teacher gradient is detached right away, so there is no need
        # to build a differentiable graph for it.
        norm_grad_t = self._input_grad(out_t, img, target, create_graph=False)
        return F.mse_loss(norm_grad_s, norm_grad_t.detach())

    @staticmethod
    def _input_grad(out, img, target, create_graph):
        '''Per-sample, L2-normalized gradient of the target output w.r.t. img.'''
        target_out = torch.gather(out, 1, target.view(-1, 1))
        grad_img = grad(outputs=target_out, inputs=img,
                        grad_outputs=torch.ones_like(target_out),
                        create_graph=create_graph, retain_graph=True)[0]
        return F.normalize(grad_img.view(grad_img.size(0), -1), p=2, dim=1)
13、BSS
Paper link: https://arxiv.org/pdf/1805.05532.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.gradcheck import zero_gradients
'''
Modified from https://github.com/bhheo/BSS_distillation
'''
def reduce_sum(x, keepdim=True):
    '''
    Sum x over every dimension except dim 0.

    Args:
        x: input tensor.
        keepdim: keep the reduced dimensions as size-1 dims when True.

    Returns:
        The reduced tensor; x itself when it has no dimension beyond dim 0.
    '''
    dims = tuple(range(1, x.dim()))
    if not dims:
        # Nothing to reduce.
        return x
    return x.sum(dim=dims, keepdim=keepdim)
def l2_norm(x, keepdim=True):
    '''
    L2 norm of x taken over every dimension except dim 0.

    Args:
        x: input tensor.
        keepdim: forwarded to reduce_sum; keeps reduced dims as size 1.

    Returns:
        Tensor with one norm per index along dim 0.
    '''
    squared_sum = reduce_sum(x * x, keepdim=keepdim)
    return squared_sum.sqrt()
class BSS(nn.Module):
    '''
    Knowledge Distillation with Adversarial Samples Supporting Decision Boundary
    https://arxiv.org/pdf/1805.05532.pdf

    Args:
        T: softmax temperature; both outputs are divided by it.
    '''

    def __init__(self, T):
        super(BSS, self).__init__()
        self.T = T

    def forward(self, attacked_out_s, attacked_out_t):
        '''
        Args:
            attacked_out_s: student outputs on the attacked samples.
            attacked_out_t: teacher outputs on the same samples.

        Returns:
            Scalar tensor: batch-mean KL divergence between the softened
            distributions. The result is not multiplied by T * T; that
            factor is left disabled, as in the original code.
        '''
        log_prob_s = F.log_softmax(attacked_out_s / self.T, dim=1)
        prob_t = F.softmax(attacked_out_t / self.T, dim=1)
        return F.kl_div(log_prob_s, prob_t, reduction='batchmean')
class BSSAttacker():
    '''
    Iterative gradient attack that moves samples towards the decision
    boundary between their target class and a chosen attack class.

    Args:
        step_alpha: base step size of each update.
        num_steps: maximum number of update steps.
        eps: constant added to the score gap that scales the step.
        clip_min: lower bound the attacked image is clamped to.
        clip_max: upper bound the attacked image is clamped to.
    '''

    def __init__(self, step_alpha, num_steps, eps=1e-4,
                 clip_min=-2.5, clip_max=2.5):
        self.step_alpha = step_alpha
        self.num_steps = num_steps
        self.eps = eps
        self.clip_min = clip_min
        self.clip_max = clip_max

    def attack(self, model, img, target, attack_class):
        '''
        Args:
            model: callable returning a 6-tuple whose last element holds
                   the class scores.
            img: batch of inputs; a detached copy is attacked, the
                 caller's tensor is left untouched.
            target: 1-D tensor with the class index of every sample.
            attack_class: 1-D tensor with the class to move towards.

        Returns:
            The attacked batch (a leaf tensor with requires_grad=True).

        Note:
            loss.backward() also accumulates gradients into the parameters
            of model.
        '''
        img = img.detach().requires_grad_(True)

        for _ in range(self.num_steps):
            # Reset the gradient by hand; torch.autograd.gradcheck's
            # zero_gradients helper is not available in newer PyTorch
            # releases.
            if img.grad is not None:
                img.grad.detach_()
                img.grad.zero_()

            _, _, _, _, _, output = model(img)

            score = F.softmax(output, dim=1)
            score_target = score.gather(1, target.unsqueeze(1))
            score_attack_class = score.gather(1, attack_class.unsqueeze(1))

            loss = (score_attack_class - score_target).sum()
            loss.backward()

            # Only samples still classified as their target keep moving.
            still_correct = (target == output.max(1)[1]).float()
            step_alpha = (self.step_alpha * still_correct).view(-1, 1, 1, 1)
            if step_alpha.sum() == 0:
                break

            with torch.no_grad():
                pert = (score_target - score_attack_class).unsqueeze(1).unsqueeze(1)
                # Clamp the norm so an all-zero gradient yields a zero step
                # instead of NaNs.
                grad_norm = l2_norm(img.grad).clamp(min=1e-12)
                norm_pert = step_alpha * (pert + self.eps) * img.grad / grad_norm
                step_adv = torch.clamp(img + norm_pert,
                                       self.clip_min, self.clip_max)
                img.copy_(step_adv)

        return img
14、CC
Paper link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
'''
CC with P-order Taylor Expansion of Gaussian RBF kernel
'''
class CC(nn.Module):
    '''
    Correlation Congruence for Knowledge Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/
    Peng_Correlation_Congruence_for_Knowledge_Distillation_ICCV_2019_paper.pdf

    Args:
        gamma: kernel parameter; the series below uses 2 * gamma.
        P_order: highest power kept in the truncated series.
    '''

    def __init__(self, gamma, P_order):
        super(CC, self).__init__()
        self.gamma = gamma
        self.P_order = P_order

    def forward(self, feat_s, feat_t):
        '''
        Args:
            feat_s: 2-D student features, one row per sample.
            feat_t: 2-D teacher features with the same number of rows.

        Returns:
            Scalar tensor: MSE between the two correlation matrices.
        '''
        corr_mat_s = self.get_correlation_matrix(feat_s)
        corr_mat_t = self.get_correlation_matrix(feat_t)
        return F.mse_loss(corr_mat_s, corr_mat_t)

    def get_correlation_matrix(self, feat):
        '''
        Pairwise correlation of the rows of feat:
        sum_{p=0..P_order} exp(-2*gamma) * (2*gamma)^p / p! * sim^p,
        where sim is the cosine similarity matrix.
        '''
        feat = F.normalize(feat, p=2, dim=-1)
        sim_mat = torch.matmul(feat, feat.t())

        # Loop-invariant factor of every term.
        scale = math.exp(-2 * self.gamma)
        corr_mat = torch.zeros_like(sim_mat)
        for p in range(self.P_order + 1):
            coeff = scale * (2 * self.gamma) ** p / math.factorial(p)
            corr_mat = corr_mat + coeff * torch.pow(sim_mat, p)

        return corr_mat
15、LwM
Paper link: https://arxiv.org/pdf/1811.08051.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
'''
LwM is originally an incremental learning method with
classification/distillation/attention distillation losses.
Here, LwM is only defined as the Grad-CAM based attention distillation.
'''
class LwM(nn.Module):
    '''
    Learning without Memorizing
    https://arxiv.org/pdf/1811.08051.pdf

    Grad-CAM based attention distillation: L1 distance between the
    normalized class-activation maps of the student and the teacher.
    '''

    def __init__(self):
        super(LwM, self).__init__()

    def forward(self, out_s, fm_s, out_t, fm_t, target):
        '''
        Args:
            out_s: student outputs, computed from fm_s.
            fm_s: 4-D student feature map that is part of out_s's graph.
            out_t: teacher outputs, computed from fm_t.
            fm_t: 4-D teacher feature map that is part of out_t's graph.
            target: index tensor selecting one output column per sample.

        Returns:
            Scalar tensor. The teacher map is detached, so only the
            student side receives gradient from this loss.
        '''
        # The teacher map is detached below, so its gradient does not need
        # a differentiable graph.
        norm_cam_t = self._grad_cam(out_t, fm_t, target, create_graph=False)
        norm_cam_s = self._grad_cam(out_s, fm_s, target, create_graph=True)
        return F.l1_loss(norm_cam_s, norm_cam_t.detach())

    @staticmethod
    def _grad_cam(out, fm, target, create_graph):
        '''Flattened, L2-normalized Grad-CAM map of the target output.'''
        target_out = torch.gather(out, 1, target.view(-1, 1))
        grad_fm = grad(outputs=target_out, inputs=fm,
                       grad_outputs=torch.ones_like(target_out),
                       create_graph=create_graph, retain_graph=True)[0]

        # Channel weights: spatial average of the gradients.
        weights = F.adaptive_avg_pool2d(grad_fm, 1)
        # Grad-CAM weights the feature maps themselves by these channel
        # weights (not the gradients a second time).
        cam = torch.sum(torch.mul(weights, fm), dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = cam.view(cam.size(0), -1)
        return F.normalize(cam, p=2, dim=1)
16、IRG
Paper link: http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class IRG(nn.Module):
    '''
    Knowledge Distillation via Instance Relationship Graph
    http://openaccess.thecvf.com/content_CVPR_2019/papers/
    Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf

    The official code is written by Caffe
    https://github.com/yufanLIU/IRG

    Args:
        w_irg_vert: weight of the vertex term (MSE between outputs).
        w_irg_edge: weight of the edge term (pairwise instance distances).
        w_irg_tran: weight of the transformation term (distance between
                    two layers of the same network).
    '''

    def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
        super(IRG, self).__init__()
        self.w_irg_vert = w_irg_vert
        self.w_irg_edge = w_irg_edge
        self.w_irg_tran = w_irg_tran

    def forward(self, irg_s, irg_t):
        '''
        Args:
            irg_s: 4-tuple (fm_s1, fm_s2, feat_s, out_s) from the student.
            irg_t: 4-tuple (fm_t1, fm_t2, feat_t, out_t) from the teacher.

        Returns:
            Scalar tensor: weighted sum of the vertex, edge and
            transformation terms.
        '''
        fm_s1, fm_s2, feat_s, out_s = irg_s
        fm_t1, fm_t2, feat_t, out_t = irg_t

        loss_irg_vert = F.mse_loss(out_s, out_t)

        # Edge term: average over the feature vector and both feature maps.
        loss_irg_edge = (
            F.mse_loss(self.euclidean_dist_feat(feat_s, squared=True),
                       self.euclidean_dist_feat(feat_t, squared=True)) +
            F.mse_loss(self.euclidean_dist_fm(fm_s1, squared=True),
                       self.euclidean_dist_fm(fm_t1, squared=True)) +
            F.mse_loss(self.euclidean_dist_fm(fm_s2, squared=True),
                       self.euclidean_dist_fm(fm_t2, squared=True))
        ) / 3.0

        irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
        irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
        loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)

        return (self.w_irg_vert * loss_irg_vert +
                self.w_irg_edge * loss_irg_edge +
                self.w_irg_tran * loss_irg_tran)

    def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
        '''
        Calculating the IRG Transformation, where fm1 precedes fm2 in the network.

        fm1 is average-pooled to fm2's spatial size when it is taller, and
        fm2's adjacent channel pairs are averaged when it has more channels
        than fm1. Returns one distance per sample, divided by the largest.
        '''
        if fm1.size(2) > fm2.size(2):
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
        if fm1.size(1) < fm2.size(1):
            fm2 = (fm2[:, 0::2, :, :] + fm2[:, 1::2, :, :]) / 2.0

        fm1 = fm1.view(fm1.size(0), -1)
        fm2 = fm2.view(fm2.size(0), -1)
        fms_dist = torch.sum(torch.pow(fm1 - fm2, 2), dim=-1).clamp(min=eps)

        if not squared:
            fms_dist = fms_dist.sqrt()

        return fms_dist / fms_dist.max()

    def euclidean_dist_fm(self, fm, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feature map.

        Each sample is flattened to a vector first.
        '''
        flat = fm.view(fm.size(0), -1)
        return self._normalized_pdist(flat, squared, eps)

    def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
        '''
        Calculating the IRG edge of feat.
        '''
        return self._normalized_pdist(feat, squared, eps)

    def _normalized_pdist(self, x, squared, eps):
        '''
        Pairwise distances between the rows of x with a zeroed diagonal,
        divided by the largest entry.
        '''
        sq_norm = x.pow(2).sum(dim=1)
        prod = torch.mm(x, x.t())
        dist = (sq_norm.unsqueeze(0) + sq_norm.unsqueeze(1)
                - 2 * prod).clamp(min=eps)

        if not squared:
            dist = dist.sqrt()

        # Work on a copy before overwriting the diagonal in place.
        dist = dist.clone()
        diag = range(len(x))
        dist[diag, diag] = 0

        return dist / dist.max()
17、VID
Paper link: https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def conv1x1(in_channels, out_channels):
    '''
    Build a bias-free 1x1 convolution (stride 1, no padding).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.

    Returns:
        The nn.Conv2d layer.
    '''
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1,
                     padding=0, bias=False)
'''
Modified from https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py
'''
class VID(nn.Module):
    '''
    Variational Information Distillation for Knowledge Transfer
    https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf

    Args:
        in_channels: channels of the student feature map.
        mid_channels: channels of the regressor's hidden layers.
        out_channels: channels of the teacher feature map.
        init_var: initial value of the per-channel predicted variance.
        eps: constant added to the variance to keep it positive.
    '''

    def __init__(self, in_channels, mid_channels, out_channels, init_var, eps=1e-6):
        super(VID, self).__init__()
        self.eps = eps
        # Three 1x1 convolutions mapping student features to the predicted
        # mean of the teacher features.
        self.regressor = nn.Sequential(
            conv1x1(in_channels, mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, mid_channels),
            nn.ReLU(),
            conv1x1(mid_channels, out_channels),
        )
        # Inverse softplus of (init_var - eps), so that the variance
        # computed in forward() starts at init_var.
        alpha_init = float(np.log(np.exp(init_var - eps) - 1.0))
        self.alpha = nn.Parameter(torch.full((out_channels,), alpha_init))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: 4-D student feature map with in_channels channels.
            fm_t: 4-D teacher feature map with out_channels channels and
                  the same spatial size as the regressor output.

        Returns:
            Scalar tensor: mean of 0.5 * (log(var) + (mean - fm_t)^2 / var).
        '''
        pred_mean = self.regressor(fm_s)
        # softplus(alpha) equals log(1 + exp(alpha)) but does not overflow
        # for large alpha.
        pred_var = F.softplus(self.alpha) + self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        neg_log_prob = 0.5 * (torch.log(pred_var) + (pred_mean - fm_t) ** 2 / pred_var)
        return torch.mean(neg_log_prob)
18、OFD
Paper link: http://openaccess.thecvf.com/content_ICCV_2019/papers/Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
'''
Modified from https://github.com/clovaai/overhaul-distillation/blob/master/CIFAR-100/distiller.py
'''
class OFD(nn.Module):
    '''
    A Comprehensive Overhaul of Feature Distillation
    http://openaccess.thecvf.com/content_ICCV_2019/papers/
    Heo_A_Comprehensive_Overhaul_of_Feature_Distillation_ICCV_2019_paper.pdf

    Args:
        in_channels: channels of the student feature map.
        out_channels: channels of the teacher feature map.
    '''

    def __init__(self, in_channels, out_channels):
        super(OFD, self).__init__()
        # 1x1 convolution followed by batch norm, mapping student features
        # to the teacher's channel count.
        self.connector = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t):
        '''
        Args:
            fm_s: 4-D student feature map with in_channels channels.
            fm_t: 4-D teacher feature map with out_channels channels.

        Returns:
            Scalar tensor: masked mean squared error between the connected
            student features and the margin-clipped teacher features.
        '''
        # Clip the teacher from below at its per-channel negative margin.
        margin = self.get_margin(fm_t)
        fm_t = torch.max(fm_t, margin)
        fm_s = self.connector(fm_s)

        # Positions where the teacher is non-positive and the student is
        # already at or below it contribute nothing.
        mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
        return torch.mean((fm_s - fm_t) ** 2 * mask)

    def get_margin(self, fm, eps=1e-6):
        '''
        Per-channel mean of the negative entries of fm, shape (1, C, 1, 1).
        eps guards the division for channels without negative entries.
        '''
        mask = (fm < 0.0).float()
        masked_fm = fm * mask
        negative_sum = masked_fm.sum(dim=(0, 2, 3), keepdim=True)
        negative_count = mask.sum(dim=(0, 2, 3), keepdim=True)
        return negative_sum / (negative_count + eps)
19、AFD
Paper link: https://openreview.net/pdf?id=ryxyCeHtPB
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
'''
In the original paper, AFD is one of components of AFDS.
AFDS: Attention Feature Distillation and Selection
AFD: Attention Feature Distillation
AFS: Attention Feature Selection
We find the original implementation of attention is unstable, thus we replace it with a SE block.
'''
class AFD(nn.Module):
    '''
    Pay Attention to Features, Transfer Learn Faster CNNs
    https://openreview.net/pdf?id=ryxyCeHtPB

    Args:
        in_channels: channels of the (student and teacher) feature maps.
        att_f: ratio used to size the hidden layer of the attention block
               (int(in_channels * att_f) channels).
    '''

    def __init__(self, in_channels, att_f):
        super(AFD, self).__init__()
        mid_channels = int(in_channels * att_f)

        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 1, 1, 0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, in_channels, 1, 1, 0, bias=True),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, fm_s, fm_t, eps=1e-6):
        '''
        Args:
            fm_s: 4-D student feature map with in_channels channels.
            fm_t: 4-D teacher feature map of the same shape.
            eps: stabilizer for the per-channel norm division.

        Returns:
            Scalar tensor: attention-weighted squared difference of the
            normalized maps, summed over channels and averaged over the batch.
        '''
        # Per-channel attention weights derived from the pooled teacher map.
        fm_t_pooled = F.adaptive_avg_pool2d(fm_t, 1)
        rho = self.attention(fm_t_pooled)
        # Flatten to (batch, channels) explicitly; a bare squeeze() would
        # also drop the batch dimension when the batch size is 1.
        rho = torch.sigmoid(rho.view(rho.size(0), -1))
        rho = rho / torch.sum(rho, dim=1, keepdim=True)

        # Scale every channel map to unit L2 norm over its spatial dims.
        fm_s_norm = torch.norm(fm_s, dim=(2, 3), keepdim=True)
        fm_s = torch.div(fm_s, fm_s_norm + eps)
        fm_t_norm = torch.norm(fm_t, dim=(2, 3), keepdim=True)
        fm_t = torch.div(fm_t, fm_t_norm + eps)

        loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=(2, 3))
        return loss.sum(1).mean(0)
20、CRD
Paper link: https://openreview.net/pdf?id=SkgpBJrtvS
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
'''
Modified from https://github.com/HobbitLong/RepDistiller/tree/master/crd
'''
class CRD(nn.Module):
    '''
    Contrastive Representation Distillation
    https://openreview.net/pdf?id=SkgpBJrtvS

    includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side

    Args:
        s_dim: the dimension of student's feature
        t_dim: the dimension of teacher's feature
        feat_dim: the dimension of the projection space
        nce_n: number of negatives paired with each positive
        nce_t: the temperature
        nce_mom: the momentum for updating the memory buffer
        n_data: the number of samples in the training set, which is the M in Eq.(19)
    '''

    def __init__(self, s_dim, t_dim, feat_dim, nce_n, nce_t, nce_mom, n_data):
        super(CRD, self).__init__()
        self.embed_s = Embed(s_dim, feat_dim)
        self.embed_t = Embed(t_dim, feat_dim)
        self.contrast = ContrastMemory(feat_dim, n_data, nce_n, nce_t, nce_mom)
        self.criterion_s = ContrastLoss(n_data)
        self.criterion_t = ContrastLoss(n_data)

    def forward(self, feat_s, feat_t, idx, sample_idx):
        '''
        Args:
            feat_s: student features of the batch.
            feat_t: teacher features of the batch.
            idx: dataset indices of the batch samples.
            sample_idx: dataset indices of the contrast samples for every
                        batch sample.

        Returns:
            Scalar tensor: sum of the student-anchored and the
            teacher-anchored contrastive losses.
        '''
        embedded_s = self.embed_s(feat_s)
        embedded_t = self.embed_t(feat_t)
        out_s, out_t = self.contrast(embedded_s, embedded_t, idx, sample_idx)
        return self.criterion_s(out_s) + self.criterion_t(out_t)
class Embed(nn.Module):
    '''
    Linear projection followed by L2 normalization.

    Args:
        in_dim: length of the flattened input feature.
        out_dim: length of the projected feature.
    '''

    def __init__(self, in_dim, out_dim):
        super(Embed, self).__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        '''
        Flatten every sample, project it, and scale the result to unit
        L2 norm along dim 1.
        '''
        flat = x.view(x.size(0), -1)
        projected = self.linear(flat)
        return F.normalize(projected, p=2, dim=1)
class ContrastLoss(nn.Module):
    '''
    contrastive loss, corresponding to Eq.(18)

    Args:
        n_data: number of samples in the training set.
        eps: constant added to the denominators.
    '''

    def __init__(self, n_data, eps=1e-7):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data
        self.eps = eps

    def forward(self, x):
        '''
        Args:
            x: 2-D tensor of scores; column 0 holds the positive pair and
               the remaining columns hold the negatives.

        Returns:
            Scalar tensor: negative sum of the log terms divided by the
            batch size.
        '''
        bs = x.size(0)
        N = x.size(1) - 1
        M = float(self.n_data)
        noise = N / M

        # loss for positive pair
        pos_pair = x.select(1, 0)
        log_pos = torch.log(pos_pair / (pos_pair + noise + self.eps))

        # loss for negative pair
        neg_pair = x.narrow(1, 1, N)
        log_neg = torch.log(torch.full_like(neg_pair, noise)
                            / (neg_pair + noise + self.eps))

        return -(log_pos.sum() + log_neg.sum()) / bs
class ContrastMemory(nn.Module):
    '''
    Memory banks holding one student and one teacher feature per training
    sample, used to score positives and negatives for CRD.

    Args:
        feat_dim: length of the stored features.
        n_data: number of rows of each memory bank.
        nce_n: number of negatives per positive.
        nce_t: temperature dividing the dot products.
        nce_mom: momentum of the memory update.
    '''

    def __init__(self, feat_dim, n_data, nce_n, nce_t, nce_mom):
        super(ContrastMemory, self).__init__()
        self.N = nce_n
        self.T = nce_t
        self.momentum = nce_mom
        # Normalization constants, estimated from the first batch. They are
        # plain attributes, so they are not part of the state_dict.
        self.Z_t = None
        self.Z_s = None

        # Both banks start uniformly distributed in [-stdv, stdv].
        stdv = 1. / math.sqrt(feat_dim / 3.)
        self.register_buffer('memory_t', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))
        self.register_buffer('memory_s', torch.rand(n_data, feat_dim).mul_(2 * stdv).add_(-stdv))

    def forward(self, feat_s, feat_t, idx, sample_idx):
        '''
        Args:
            feat_s: student features, (batch, feat_dim).
            feat_t: teacher features, (batch, feat_dim).
            idx: memory rows of the batch samples (batch entries).
            sample_idx: memory rows to contrast against,
                        batch * (nce_n + 1) entries.

        Returns:
            (out_s, out_t): scores of shape (batch, nce_n + 1) for the
            student-anchored and the teacher-anchored side.

        Side effects:
            Sets Z_s / Z_t on the first call and updates the memory rows
            selected by idx.
        '''
        bs = feat_s.size(0)
        feat_dim = self.memory_s.size(1)
        n_data = self.memory_s.size(0)
        flat_idx = idx.view(-1)
        flat_sample_idx = sample_idx.view(-1)

        # using teacher as anchor
        weight_s = torch.index_select(self.memory_s, 0, flat_sample_idx).detach()
        weight_s = weight_s.view(bs, self.N + 1, feat_dim)
        out_t = torch.bmm(weight_s, feat_t.view(bs, feat_dim, 1))
        # Drop only the trailing singleton dim so a batch of one keeps its
        # batch dimension.
        out_t = torch.exp(torch.div(out_t, self.T)).squeeze(2).contiguous()

        # using student as anchor
        weight_t = torch.index_select(self.memory_t, 0, flat_sample_idx).detach()
        weight_t = weight_t.view(bs, self.N + 1, feat_dim)
        out_s = torch.bmm(weight_t, feat_s.view(bs, feat_dim, 1))
        out_s = torch.exp(torch.div(out_s, self.T)).squeeze(2).contiguous()

        # set Z if haven't been set yet
        if self.Z_t is None:
            self.Z_t = (out_t.mean() * n_data).detach().item()
        if self.Z_s is None:
            self.Z_s = (out_s.mean() * n_data).detach().item()

        out_t = torch.div(out_t, self.Z_t)
        out_s = torch.div(out_s, self.Z_s)

        # update memory
        with torch.no_grad():
            pos_mem_t = torch.index_select(self.memory_t, 0, flat_idx)
            pos_mem_t.mul_(self.momentum)
            pos_mem_t.add_(torch.mul(feat_t, 1 - self.momentum))
            pos_mem_t = F.normalize(pos_mem_t, p=2, dim=1)
            self.memory_t.index_copy_(0, flat_idx, pos_mem_t)

            pos_mem_s = torch.index_select(self.memory_s, 0, flat_idx)
            pos_mem_s.mul_(self.momentum)
            pos_mem_s.add_(torch.mul(feat_s, 1 - self.momentum))
            pos_mem_s = F.normalize(pos_mem_s, p=2, dim=1)
            self.memory_s.index_copy_(0, flat_idx, pos_mem_s)

        return out_s, out_t
21、DML
Paper link: https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf
Code:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
'''
DML with only two networks
'''
class DML(nn.Module):
    '''
    Deep Mutual Learning
    https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

    One direction of the mutual-learning loss: KL divergence that pulls the
    distribution of out1 towards the distribution of out2.
    '''

    def __init__(self):
        super(DML, self).__init__()

    def forward(self, out1, out2):
        '''
        Args:
            out1: outputs of the network being trained; log-softmax is
                  taken over dim 1.
            out2: outputs of the peer network; softmax is taken over dim 1.

        Returns:
            Scalar tensor: batch-mean KL divergence.
        '''
        log_prob_1 = F.log_softmax(out1, dim=1)
        prob_2 = F.softmax(out2, dim=1)
        return F.kl_div(log_prob_1, prob_2, reduction='batchmean')