Channel-wise Knowledge Distillation for Dense Prediction (ICCV 2021): Principle and Code Analysis

Paper: Channel-wise Knowledge Distillation for Dense Prediction

Official implementation: https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill

Summary 

Most previous distillation methods for dense prediction tasks align the activation maps of the teacher and student networks in the spatial domain: the activation values of each spatial location are normalized, and knowledge is transferred by reducing the point-wise and/or pair-wise differences.

Different from previous methods, this paper proposes to normalize the activation map of each channel to obtain a soft probability map. By minimizing the KL divergence between the channel-wise probability maps of the two networks, the distillation process pays more attention to the most salient regions of each channel, which are very valuable for dense prediction tasks.

Background

Dense prediction tasks are pixel-level prediction problems and are more challenging than image-level classification. Previous studies have found that directly applying classification-oriented distillation methods to semantic segmentation gives unsatisfactory results. Strictly aligning point-wise classification scores or feature maps between the teacher and student networks may impose overly strict constraints and lead only to suboptimal solutions.

Some recent studies mainly focus on strengthening the correlation between different spatial locations. As shown in Figure 2(a), the activation values at each spatial location are normalized, and then task-specific relations, such as pair-wise relations and inter-class relations, are obtained by aggregating subsets of different spatial locations. These methods may capture spatial structure information better than point-wise alignment and improve the performance of the student network. However, every spatial location in the activation map contributes equally to the knowledge transfer, which may bring redundant information from the teacher network.

In this paper, a new channel-wise knowledge distillation method for dense prediction tasks is proposed: the activation map of each channel is normalized, as shown in Figure 2(b), and the KL divergence between the normalized channel activation maps of the teacher and student networks is then minimized. An example of the channel-wise distribution is shown in Figure 2(c); the activation map of each channel tends to encode the salient regions of a specific scene category. For each channel, the student network is guided to pay more attention to the regions where the teacher has large activation values, leading to more accurate localization in dense prediction tasks. For example, in object detection, the student network is encouraged to focus on learning the activations of foreground regions.
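To make the difference between Figure 2(a) and Figure 2(b) concrete: spatial normalization applies a softmax across the channels at each location, while the channel-wise normalization used in this paper applies a softmax across all spatial locations of each channel. A minimal sketch (the tensor shape and variable names are arbitrary examples, not from the paper):

import torch
import torch.nn.functional as F

act = torch.randn(1, 19, 128, 128)  # (N, C, H, W), e.g. 19 classes for Cityscapes

# Figure 2(a): spatial normalization -- a distribution over the C channels at each location
spatial_norm = F.softmax(act, dim=1)              # (1, 19, 128, 128), sums to 1 along dim=1

# Figure 2(b): channel-wise normalization -- a distribution over the H*W locations of each channel
channel_norm = F.softmax(act.flatten(2), dim=-1)  # (1, 19, 16384), sums to 1 along the last dim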

Contributions of this paper

  • Different from existing spatial distillation methods, this paper proposes a new channel distillation paradigm for dense prediction tasks, which is simple and effective.
  • On both semantic segmentation and object detection, the channel-wise distillation method proposed in this paper consistently outperforms state-of-the-art KD methods.
  • We achieve consistent improvements on four benchmark datasets on semantic segmentation and object detection tasks, demonstrating the generalizability of our approach. Given its simplicity and effectiveness, we believe our method can serve as a strong baseline KD method for dense prediction tasks.

Method introduction

In order to better utilize the knowledge in each channel, the authors propose to softly align the activations of the corresponding channels of the teacher and student networks. To do this, the activations of each channel are first converted into a probability distribution, so that the difference can be measured with a probabilistic distance metric such as the KL divergence. As shown in Figure 2(c), the activations of different channels tend to encode the saliency of a specific scene category in the input image. Furthermore, a well-trained teacher model for semantic segmentation shows clear category-specific mask activations in each channel, as expected and as shown on the right side of Figure 1. Therefore, the authors propose a new channel-wise distillation paradigm to guide the student to learn knowledge from a well-trained teacher.

First, define the teacher network and the student network as \(T\) and \(S\) respectively, and denote their activations as \(y^{T}\) and \(y^{S}\). The general expression of the channel-wise distillation loss is as follows:

\[\varphi\big(\phi(y^{T}), \phi(y^{S})\big)\]

where \(\phi(\cdot)\) converts the activation values into a probability distribution:

\[\phi(y_{c}) = \frac{\exp\!\left(\frac{y_{c,i}}{\mathcal{T}}\right)}{\sum_{i=1}^{W\cdot H}\exp\!\left(\frac{y_{c,i}}{\mathcal{T}}\right)} \qquad (3)\]

where \(c=1,2,\dots,C\) indexes the channels, \(i\) indexes the spatial locations within a channel, and \(\mathcal{T}\) is the temperature hyperparameter. A larger \(\mathcal{T}\) makes the probability distribution softer, meaning that a wider spatial region is attended to in each channel. The softmax normalization also removes the difference in activation scale between the large and small networks. If the numbers of channels of the teacher and student networks do not match, a 1x1 convolution is used to upsample the number of channels of the student network so that the two are equal. \(\varphi(\cdot)\) evaluates the difference between the channel distributions of the teacher and the student; specifically, the KL divergence is used:

\[\varphi\big(\phi(y^{T}), \phi(y^{S})\big) = \frac{\mathcal{T}^{2}}{C}\sum_{c=1}^{C}\sum_{i=1}^{W\cdot H}\phi(y^{T}_{c,i})\cdot\log\!\left[\frac{\phi(y^{T}_{c,i})}{\phi(y^{S}_{c,i})}\right] \qquad (4)\]

The KL divergence is an asymmetric measure. It can be seen from formula (4) that if \(\phi(y^{T}_{c,i})\) is large, \(\phi(y^{S}_{c,i})\) should be as large as \(\phi(y^{T}_{c,i})\) to reduce the KL divergence. Conversely, if \(\phi(y^{T}_{c,i})\) is very small, the KL divergence pays relatively little attention to reducing \(\phi(y^{S}_{c,i})\). Under the supervision of the teacher network, the student therefore tends to produce activation distributions similar to the teacher's in the salient foreground regions, while the teacher's activations in background regions have less influence on the student. The authors argue that this asymmetry of the KL divergence benefits learning in distillation for dense prediction tasks.
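As a concrete illustration, the sketch below transcribes Eq. (3) and Eq. (4) directly into PyTorch. It is only an illustrative sketch, not the official code (which is analyzed later in this post); the function name, the alignment layer, and the channel counts are assumptions for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

def cwd_loss(y_s, y_t, temperature=1.0):
    # y_s, y_t: student / teacher activation maps of shape (N, C, H, W)
    n, c, h, w = y_t.shape
    # Eq. (3): softmax over the H*W locations of each channel (student kept in log space)
    log_phi_s = F.log_softmax(y_s.view(n, c, -1) / temperature, dim=-1)
    phi_t = F.softmax(y_t.detach().view(n, c, -1) / temperature, dim=-1)
    # Eq. (4): KL(phi(y^T) || phi(y^S)), summed over channels and locations,
    # averaged over N*C (as in the official implementation) and scaled by T^2
    return F.kl_div(log_phi_s, phi_t, reduction='sum') * (temperature ** 2) / (n * c)

# if the student has fewer channels than the teacher, a 1x1 convolution can be used
# to upsample the student's channel number first (the channel counts here are made up)
align = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)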

Difference from channel-wise distillation in classification tasks

In the paper Channel Distillation: Channel-Wise Attention for Knowledge Distillation, the authors also propose channel-wise distillation, but it is mainly applied to classification tasks. Inspired by SENet, the feature map of each channel is converted into a scalar by global average pooling (GAP), and the KL divergence is then applied to measure the difference between the corresponding channel scalars of the teacher and student networks.

This paper mainly considers dense prediction tasks. GAP may be helpful for image-level classification, but it weights all spatial positions equally and discards spatial information, so it is not suitable for dense prediction. In this paper, the softmax normalization takes the importance of different spatial locations into account and preserves spatial information, making it better suited for dense prediction tasks.
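The shape of the distilled quantity alone makes this difference clear; a small sketch (variable names and shapes are illustrative, not from either paper):

import torch
import torch.nn.functional as F

feat = torch.randn(2, 512, 64, 64)       # (N, C, H, W) feature map

# GAP-based channel distillation for classification: one scalar per channel, spatial layout lost
gap = feat.mean(dim=(2, 3))              # shape (2, 512)

# channel-wise softmax used in this paper: a distribution over all H*W locations of each channel
cw = F.softmax(feat.flatten(2), dim=-1)  # shape (2, 512, 4096), spatial information preserved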

Experimental results

Table 2 compares the complexity of the channel-wise distillation proposed in this paper with other spatial distillation methods on the Cityscapes dataset, together with the mIoU on the validation set. Channel-wise distillation achieves the highest accuracy with low complexity.

Table 5 compares the accuracy of different student models trained with different distillation methods on the Cityscapes dataset. For student networks with different structures, the channel-wise distillation proposed in this paper outperforms the other distillation methods.

 

Table 6 compares with other distillation methods on the object detection task. The channel-wise distillation proposed in this paper performs best across two-stage, one-stage, and anchor-free detectors.

Code analysis

The official implementation is shown below, where ChannelNorm performs the channel-wise softmax normalization and the KL divergence loss corresponds to the method given in the paper. The official implementation also provides other normalization options and loss functions.

import torch.nn as nn


class ChannelNorm(nn.Module):
    def __init__(self):
        super(ChannelNorm, self).__init__()

    def forward(self, featmap):
        # flatten the spatial dimensions and take a softmax over them for each channel
        n, c, h, w = featmap.shape
        featmap = featmap.reshape((n, c, -1))
        featmap = featmap.softmax(dim=-1)
        return featmap


class CriterionCWD(nn.Module):

    def __init__(self, norm_type='none', divergence='mse', temperature=1.0):

        super(CriterionCWD, self).__init__()

        # define normalize function
        if norm_type == 'channel':
            self.normalize = ChannelNorm()
        elif norm_type == 'spatial':
            self.normalize = nn.Softmax(dim=1)
        elif norm_type == 'channel_mean':
            self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)
        else:
            self.normalize = None
        self.norm_type = norm_type

        self.temperature = 1.0

        # define loss function
        if divergence == 'mse':
            self.criterion = nn.MSELoss(reduction='sum')
        elif divergence == 'kl':
            self.criterion = nn.KLDivLoss(reduction='sum')
            self.temperature = temperature
        self.divergence = divergence

    def forward(self, preds_S, preds_T):

        n, c, h, w = preds_S.shape
        if self.normalize is not None:
            # divide by the temperature before normalizing; the teacher is detached
            norm_s = self.normalize(preds_S / self.temperature)
            norm_t = self.normalize(preds_T.detach() / self.temperature)
        else:
            norm_s = preds_S[0]
            norm_t = preds_T[0].detach()

        if self.divergence == 'kl':
            # nn.KLDivLoss expects the student input as log-probabilities
            norm_s = norm_s.log()
        loss = self.criterion(norm_s, norm_t)

        # average over batch and channels for the channel-wise norms (matching Eq. (4)),
        # otherwise over batch and spatial size
        if self.norm_type == 'channel' or self.norm_type == 'channel_mean':
            loss /= n * c
        else:
            loss /= n * h * w

        return loss * (self.temperature ** 2)
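
A quick usage sketch of CriterionCWD on random feature maps follows. The shapes and the temperature value are arbitrary example choices; the paper treats the temperature as a hyperparameter to be tuned.

import torch

# student / teacher feature maps with matching shapes (channels already aligned)
preds_S = torch.randn(2, 512, 65, 65, requires_grad=True)
preds_T = torch.randn(2, 512, 65, 65)

criterion = CriterionCWD(norm_type='channel', divergence='kl', temperature=4.0)
loss = criterion(preds_S, preds_T)   # scalar tensor
loss.backward()
print(loss.item())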

 


Original post: https://blog.csdn.net/ooooocj/article/details/130447878