论文-知识蒸馏

Structured Knowledge Distillation for Dense Prediction

今天看一篇沈老师他们的工作,是知识蒸馏(knowledge distillation)相关

github传送

  • Structured Knowledge Distillation for Dense Prediction

Previous knowledge distillation strategies used for dense prediction tasks often directly borrow the distillation scheme for image classification and perform knowledge distillation for each pixel separately, leading to sub-optimal performance

文章摘要说,以前的KD都是对每一个像素学习知识,会得到一个次优的解。他说的是对于dense prediction。这是为什么呢?

Here we propose to distill structured knowledge from large networks to small networks, taking into account the fact that dense prediction is a structured prediction problem.

由于dense prediction就是一个结构预测的问题,所以提出了一个‘蒸馏结构知识’的方法。有两种结构蒸馏方案,他管以前的KD叫做pixel-wise distillation

  1. pair-wise distillation

The pair-wise distillation scheme is motivated by the widely-studied pair-wise Markov random field framework. 引文23

  1. holistic distillation

The holistic distillation scheme aims to align higher-order consistencies

Specifically, we study two structured distillation schemes:
i) pair-wise distillation that distills the pairwise similarities by building a static graph;
and ii) holistic distillation that uses adversarial training to distill holistic knowledge.


  • Dense Prediction

Dense prediction is a category of fundamental problems in computer vision, which learns a mapping from input objects to complex output structures, including semantic segmentation, depth estimation and object detection, among many others.

将输入映射为复杂的结构输出,那么他的这种结构蒸馏好像不是我们图像恢复需要的?(思想好像可以照搬,但是他的设计可能更关注dense structure)


基于以上考虑,看了下pipeline
SKD
除了holistic loss不清楚怎么算的,其他好像很有道理。


好像文章都很热衷于各种trick疯狂刷指标。。。。
文章将他应用在语义分割,深度估计,目标检测


  • 方法
    I W × H × 3 W\times H\times 3 RGB输入
    F W × H × N W\times H\times N 输入I的feature map
    Q W × H × C W\times H\times C F经过a classifier计算的分割map???上采样到 W × H W\times H 作为分割结果。。。
    不知所云,往后看

  • Pixel-wise distillation
    计算S和T输出的概率图之间KL散度
    pi

  • Pair-wise distillation
    static affinity graph表示空间成对的关系
    pair-wise distillation
    lpa
    β \beta 使用平均池化
    C C 矩阵相乘消除通道
    但是前面 W × H × α β \frac{W'\times H'\times \alpha}{\beta} 是怎么回事,不是应该除以他吗?而且从 a i , j a_{i,j} 的定义来看不是还得除以C(待读代码解惑)

  • Holistic distillation
    使用全局蒸馏时,使用了conditional GAN,由于离散的JS散度,使用Wasserstein distance or Earth Mover distance
    lho
    Discriminator使用self-attention residual block,self-attention和residual block的位置和数量见论文
    以上,很多概念都不是特别懂
  • conditional GAN
  • Wasserstein distance
  • Q s Q^{s} Q t Q^{t} 是怎么embedding的

  • Optimization
    loss
    λ 2 \lambda_{2} 前为负号?(对于生成网络似乎是这样)
    discriminator
    conpact

以上大概是文章的介绍部分,有很多地方还不是很清楚,开始读代码。。。。


  • self-attention模块
    具体原理待看,可以参考这个博客,其实还是不懂,因为对RNN不太理解,代码来看就是下面这个过程,先跑起来,然后再看,MARK
    selfattn
    这个辣鸡图softmax那里画错了,就不要看了
class selfAttn(nn.Module):
    def __init__(self, dim):
        super(selfAttn, self).__init__()
        self.query_conv = nn.Conv2d(dim, dim//8, 1)
        self.key_conv = nn.Conv2d(dim, dim//8, 1)
        self.value_conv = nn.Conv2d(dim, dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
        #这里dim=-1其实是每一行的所有列进行softmax
    def forward(self, x):
        n, c, h, w = x.shape
        proj_query = self.query_conv(x).view(n, -1, h*w).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(n, -1, h*w)
        energy = torch.bmm(proj_query, proj_key) #nxwhxwh
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(n, -1, h*w)
        
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        print(out.shape)
        out = out.view(n, c, h, w)
        out = self.gamma*out + x
        print(proj_query.shape, proj_key.shape, energy.shape, attention.shape, proj_key.shape)
        return out, attention

基于知识蒸馏的超分辨率卷积神经网络压缩方法

上面的文章和这篇文章都是知其然而不知其所以然


  • motivation
    paper
    他有一篇引文说不同层次的特征图代表不同的信息,且传递中间的feature会比输出更有效,所以它采用的以下方法,使用三个层次的信息
  • 网络结构
    paper

  • loss函数
    loss
    以上超参 0 = 1+2+3,1:1就好
    loss

  • training
  1. 特征图统计
    picture
    结论:picture,这个统计量最优
    paper
    从上面来看这个应该是按通道的特征统计
  2. teacher网络的大小
    teacher网络越大越好,但是过大以后提升越来越小

Lightweight Image Super-Resolution with Information Multi-distillation Network

都找不到low-level知识蒸馏的文章,在github awesome-knowledge-distillation项目中只找到一篇超分相关的,还是西电的2333,看一下吧


recently, Zhang et al. also introduced spatial attention (non-local module) into the residual block and then constructed residual non-local attention network (RNAN) [37] for various image restoration tasks.

spatial attention for various image restoration tasks ——可以关注下


文章想要解决计算量的问题从

  • 参数量 parameters
  • depth
  • resolution
    方面进行设计
  • 文章受启发与IDN提出了IMDN(infomation multi-distillation network)
    ,构建了IMDB(blocks)提取信息包括两部分,一部分retains partial information,一部分feature treats other features,
  • aggregating features distilled这个就叫做蒸馏?
    他使用了contrast-aware channel attention layer, specifically related to the low-level vision task,他说这种和低级的视觉任务很相关,关注!

文章的contributions,最后一个是认真的?网络的深度和推断速度相关…可能我理解的错了,看他的实验吧
paper


  • 网络框架
    在这里插入图片描述
    IMDBCCA
    上面依次是网络主体IMDN,上采样UP,信息多蒸馏块IMDB,对比度感知通道注意层CCA
    疑问:1. channel split?就是将channel分为两部分(64=16+48),前面的层表示refined features
    2. contrast

在这里插入图片描述
回答上面疑问2:
CCA
CCA做了上面这样一件事,按channel统计均值和方差作为attention相加,总感觉这样的加法或者乘法没有什么道理,并不能说服我,但是他是网络学习出来的特征,可能代表一些东西。上面公式计算出来的结果是 c × 1 × 1 c\times1\times1 ?


  • Adaptive cropping stategy ACS-IMDN_AS
    他的思想就是将图像裁成四块并行计算,允许一份的重叠,计算完成后再舍弃重叠部分拼接
    他这是任意尺寸的图像放大固定倍数,而不是任意倍数的超分。。。。失望

好,到这里我们大概了解了知识蒸馏的做法,人们将他归为transfer learning迁移学习,就是将大网络的先验知识传送给小网络,那么这样做为什么可行呢??然我们从头开始学习**Knowlledge Distillation**

发布了20 篇原创文章 · 获赞 0 · 访问量 367

猜你喜欢

转载自blog.csdn.net/yywxl/article/details/103680152