FSP:Flow of Solution Procedure (CVPR 2017) 原理与代码解析

paper:A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

code:https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/FSP.py

背景

深度神经网络DNN逐层生成特征。更高层的特征更接近于任务的有用特征。如果我们把DNN的输入看作问题,把输出看作答案,我们就可以把DNN中间生成的特征看作是求解过程中的中间结果。根据这一想法,FitNets可以让学生网络简单地模拟教师网络的中间结果。然而在DNN中,有许多方法或途径来解决从输入生成输出的问题。因此,模拟教师网络生成的特征对学生网络来说是一个硬约束hard constraint。就人而言,老师解释问题的解决过程,学生学习解决问题的流程。当输入特定的问题时,学生网络不一定需要学习中间输出,但当遇到特定类型的问题时,学生网络可以学习这一类问题的通用解决方法。因此作者认为,对于知识蒸馏中的教师网络,演示问题的解决过程比演示中间结果具有更好的泛化性

本文的创新点

本文将神经网络中层与层之间的信息流动定义为需要蒸馏的知识,并通过计算两个特征层之间的内积来得到这种知识。当将这种层之间的流动作为知识传递给学生网络时,作者通过实验得到了三个结论:

  1. 从教师网络学习这种蒸馏知识的学生网络比原始网络的优化(收敛)速度快得多。

  1. 学习这种蒸馏知识的学生网络比原始网络的性能更好。

  1. 即使教师网络是在一个不同的任务或数据集上训练得到的,学生网络也可以从教师网络中学习到这种知识,并且比从头训练的效果更好。

下图是本文提出的知识蒸馏方法的概念图

本文的贡献如下:

  1. 提出了一种知识蒸馏的新方法。

  1. 这种知识对于快速优化非常有用。

  1. 利用所提出的蒸馏知识定义网络的初始权重可以提高小模型的性能。

  1. 即使学生网络接受了与教师网络不同的训练任务,所提出的蒸馏知识也能提高学生网络的表现。

方法介绍

作者设计了网络中两个相邻层之间的FSP(flow of solution procedure)矩阵来表示问题的求解过程,对于挑选的层1输出的feature map表示为 \(F^{1}\in \mathbb{R}^{h\times w\times m}\),其中 \(h,w,m\) 分别表示特征图的高、宽、通道数。层2表示为 \(F^{2}\in \mathbb{R}^{h\times w\times n}\),则FSP矩阵 \(G\in \mathbb{R}^{m\times n}\) 可通过下式求得

其中 \(x\) 表示输入图片,\(W\) 表示网络权重参数。

对于残差网络,网络在一些位置的spatial size发生变化,我们选择教师网络和学生网络对应位置具有相同spatial size的特征图来生成FSP matrix,下图是一个示例

计算教师网络和学生网络对应FSP矩阵的L2损失,完整是损失函数如下

其中 \(\lambda_{i}\) 表示每一对FSP矩阵损失的权重,文中设定所有层计算的FSP之间的损失权重相等。\(N\) 表示所有的采样点。

代码解析

forward函数的输入g_sg_t分别表示学生网络和教师网络中所有用来计算FSP矩阵的层,在compute_fsp中每一层都与相邻层计算fsp矩阵,注意这里的相邻并不是说在原始网络中这两层的相邻的。这里相邻层之间计算fsp矩阵需要保证spatial size相等,如果不相等通过自适应平均池化使之相等。

from __future__ import print_function

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')

    def forward(self, g_s, g_t):
        # [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]
        # [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group

    @staticmethod
    def compute_loss(s, t):
        return (s - t).pow(2).mean()

    @staticmethod
    def compute_fsp(g):
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]  # (64,32,32,32),(64,64,32,32)
            b_H, t_H = bot.shape[2], top.shape[2]
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            else:
                pass
            bot = bot.unsqueeze(1)  # (64,1,32,32,32)
            top = top.unsqueeze(2)  # (64,64,1,32,32)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)  # (64,1,32,1024)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)  # (64,64,1,1024)
            fsp = (bot * top).mean(-1)  # (64,64,32,1024)->(64,64,32)
            fsp_list.append(fsp)
        return fsp_list

猜你喜欢

转载自blog.csdn.net/ooooocj/article/details/129250325