Explanation-Guided Training for Cross-Domain Few-Shot Classification


Table of contents

0. Summary

0.1 Explanation of keywords and terms

1. Introduction

2. Related research

2.1 Few-shot Classification (FSC)

2.2 Cross-domain Few-shot Classification (CD-FSC)

2.3 Explanation for FSC

3. Explanation-guided Training

3.1 CD-FSC structure

3.2 Training

4. Experiment

4.1 Dataset and model preparation

4.2 Evaluation

4.3 Combination of Explanation-guided training and LFT

4.4 Analyzing the effect of explanation-guided training

4.5 Qualitative analysis of LRP

5. Reference


0. Summary

The challenges of the cross-domain few-shot classification task (CD-FSC) come mainly from two sources: each category has only a few labeled samples, and the training and test sets belong to different domains. This paper proposes a new training method built on top of existing FSC models. It uses the explanation obtained when the FSC model makes a prediction and applies this value to the model's intermediate feature maps. First, layer-wise relevance propagation is adapted to explain the predictions of the FSC model. Second, the authors develop a model-agnostic explanation-guided training strategy that dynamically finds and emphasizes the features that are important for the prediction. This work does not aim to propose a new explanation method; it focuses on a new use of explanations during the training phase.

0.1 Explanation of keywords and terms

  • cross-domain: a model learned on a source domain (e.g., trained to recognize birds) is applied to a different target domain.
  • few-shot: a pre-trained model sees only a small amount of labeled data (the support set) and must complete the task on new data (the query set).
  • N-way K-shot: the standard setting of few-shot learning. The support set contains a total of N categories, with K labeled samples per category.
  • Relevance (importance): the method in this paper strengthens features that are important for the prediction (classification) and weakens features that are not. In my opinion, this could also be called a confidence; as formula (2) shows, it reflects how confident the model is that a feature belongs to a certain category.
  • BP: in this article, BP refers to the backward propagation of relevance, which passes relevance from the back of the network to the front, finally yielding the relevance of the initial features. (This relevance propagation should not be confused with belief propagation, which is a different algorithm.)
  • Explanation (interpretation): the relevance of the initial features obtained through BP is also called the explanation. When the prediction is correct, the features corresponding to the predicted category are the most important for the prediction and receive the largest relevance; this explains why this class was predicted and not the others.
  • information bottleneck: discard unimportant information and keep the important information. The method in this paper is also based on this theory. However, information discarded as unimportant for one episode's prediction may be important for other episodes' predictions, which leads to overfitting. Therefore, although this work discards some information, it does not do so excessively.

1. Introduction

Humans can recognize new objects after seeing only a few samples, whereas training and fine-tuning ordinary classification models requires a large amount of labeled data. FSC, in contrast, can classify new categories from a small number of samples: after the model is deployed, humans label a few samples of new categories that the originally trained model has never seen. In standard FSC, the test data comes from the same domain as the training data. The challenge addressed here is generalization from a source domain to a target domain. For example, humans can identify both birds and plants from a few samples, while an FSC model trained on birds may not accurately identify different kinds of plants.

Solving this problem requires avoiding overfitting to the source domain. This paper improves CD-FSC by using explanations to guide the model toward better feature representations. Existing explanation methods include gradient-based methods, Shapley-value-based methods, LRP, and LIME. They compute a score for each dimension of a feature map that denotes its importance to the final prediction.

Although many studies have made great progress in explaining model predictions, explanations are usually used in the testing phase rather than the training phase, for example: auditing predictions, explanation-weighted document representations, and identifying biases in datasets.

The FSC model in this paper uses the LRP method. LRP has been applied to CNNs, RNNs, GNNs, and clustering. It backpropagates the relevance of the target label through the neural network and assigns relevance to the neurons in the network. The sign and magnitude of the relevance reflect a neuron's contribution to the prediction.

The figure above shows LRP explanation heatmaps of the input images (with five target labels). The model is a RelationNet trained on miniImagenet under the 5-way 5-shot setting (five categories, five samples per category). The first row shows the support images. The other two rows show explanation heatmaps for two query images. Both are classified correctly, and the heatmaps are generated with respect to different target labels. Red/blue pixels represent positive/negative LRP explanation scores, and the color intensity represents the magnitude of the score. As the figure shows, the more similar a query image is to a support image, the redder the corresponding heatmap, and vice versa.

The LRP relevance of the intermediate feature map is used as a weight to construct an LRP-weighted feature map. This step strengthens feature dimensions that are more relevant to the prediction and weakens those that are less relevant. The LRP-weighted features are then fed back into the network for training. Because the LRP explanation is computed for each sample-label pair, explanation-guided training adds a label-dependent weighting mechanism during training, which reduces overfitting to the source domain.

The explanation-guided training strategy in this paper is model-agnostic and can be combined with other CD-FSC methods, such as Learned Feature-wise Transformation (LFT).

2. Related research

2.1 Few-shot Classification (FSC)

Few-shot learning has two main directions: optimization-based and metric-based. The former learns initialization parameters that can be quickly transferred to new categories, or designs a meta-optimizer that learns how to update model parameters. The latter learns a distance metric, compares the support images with the query images, and assigns each query to the closest category. Other methods are also worth noting, such as:

  • Add a task-conditional layer to the model;
  • Dynamically update the parameters of the classifier for new categories;
  • Combine multi-modal information (e.g., word embeddings of category labels);
  • Augment the training data by hallucinating new samples;
  • Semi-supervised learning with unlabeled training data;
  • Add a self-supervised mechanism to the model.

However, these methods must still contend with the problem of domain shift.

2.2 Cross-domain Few-shot Classification (CD-FSC)

A number of CD-FSC methods have been generated based on existing FSC methods.

  • LFT learns a noise distribution during training and adds the noise to intermediate feature maps to generate more diverse features and improve the generalization of the model;
  • Combine multiple encoders and apply batch spectral regularization (BSR) to the image features of each encoder: constrain the singular values of the feature matrix within a batch so that the learned features have similar spectra across domains. This avoids overfitting to the source domain and thereby improves generalization in the target domain (a minimal sketch follows this list);
  • Methods combining first-order MAML with metric-based GNNs;
  • Use a prototypical triplet loss to increase the inter-class distance and a large-margin cosine loss to reduce the intra-class distance. Another related study shows that reducing intra-class variance benefits FSC, especially for shallow feature encoders.
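As a rough illustration of the BSR idea, here is a minimal sketch under assumptions: the penalty shown is the sum of squared singular values of the batch feature matrix, and the regularization weight is left to the caller; this is not the referenced paper's code.

```python
import torch

def batch_spectral_regularization(features: torch.Tensor) -> torch.Tensor:
    """Penalize the singular values of a batch feature matrix.

    features: (batch_size, feature_dim) matrix produced by an encoder.
    Returns the sum of squared singular values; BSR-style methods add this
    to the loss to keep feature spectra similar across domains.
    """
    # torch.linalg.svdvals returns the singular values of the matrix.
    singular_values = torch.linalg.svdvals(features)
    return (singular_values ** 2).sum()

# Usage (illustrative): total_loss = ce_loss + reg_weight * batch_spectral_regularization(f)
```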

2.3 Explanation for FSC

The FSC model can use a CNN to encode image features, and many metric-based methods use neural networks to learn the distance metric. For FSC models with a non-parametric distance metric, one study transformed the k-means classifier into a neural network structure and then used LRP to obtain explanations. This paper uses LRP because of its reasonable performance, its well-understood hyperparameters, and its reasonable training speed compared with LIME and Shapley-type methods.

3. Explanation-guided Training

3.1 CD-FSC structure

For a K-way N-shot task, a support set $S = \{(x_s, y_s)\}_{s=1}^{K \times N}$ containing K categories with N labeled samples per category is given for training, and a query set $Q = \{(x_q, y_q)\}_{q=1}^{n_q}$ with the same categories as S is used for testing. In the CD-FSC task, episodes $\{S_i, Q_i\}$ are randomly sampled from a base domain $D_{seen}$ to train the FSC model; episodes sampled from another domain $D_{unseen}$ are then used to test the model.
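To make the episode construction concrete, here is a minimal sketch of K-way N-shot episode sampling (the dataset layout, names, and the 16-image query size are illustrative assumptions, not the paper's code):

```python
import random
from typing import Dict, List, Tuple

def sample_episode(
    data_by_class: Dict[str, List[str]],  # class label -> list of image paths
    k_way: int = 5,
    n_shot: int = 5,
    n_query: int = 16,
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Sample one K-way N-shot episode: a support set and a query set."""
    classes = random.sample(sorted(data_by_class), k_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        images = random.sample(data_by_class[cls], n_shot + n_query)
        # The first n_shot images form the support set, the rest the query set.
        support += [(img, episode_label) for img in images[:n_shot]]
        query += [(img, episode_label) for img in images[n_shot:]]
    return support, query
```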

The blue path in the figure is the FSC training flow, and the red path is the explanation method (executed after the blue path completes).

The support set $S$ and query set $Q$ are encoded by a CNN (possibly followed by augmentation layers) to obtain support image features $f_s$ and query image features $f_q$, which must be processed before classification, for example:

  • Average $f_s$ per category, then concatenate the averaged class representations with $f_q$ in pairs;
  • Design an attention module that generates attention-weighted support/query image features;
  • Apply a GNN to $f_s, f_q$ to obtain features with graph structure.

The classifier predicts (classifies) based on the processed features $f_p$, using either an optimization-based method (a neural network) or a metric-based method (cosine similarity, Euclidean distance, Mahalanobis distance). The predicted result is $p$. A sketch of the metric-based option follows.
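As an illustration of the metric-based option, the sketch below computes scaled cosine similarities between query features and class prototypes (a generic sketch, not the paper's exact classifier; `beta` plays the role of the scaling constant in formula (1)):

```python
import torch
import torch.nn.functional as F

def cosine_prototype_logits(
    support_features: torch.Tensor,  # (K * N, feature_dim)
    support_labels: torch.Tensor,    # (K * N,) integer labels in [0, K)
    query_features: torch.Tensor,    # (n_query, feature_dim)
    n_classes: int,
    beta: float = 7.0,               # scaling constant, as in formula (1)
) -> torch.Tensor:
    """Return scaled cosine similarities between queries and class prototypes."""
    # Class prototype = mean of the support features of that class.
    prototypes = torch.stack([
        support_features[support_labels == c].mean(dim=0)
        for c in range(n_classes)
    ])
    # Cosine similarity between each query (rows) and each prototype (columns).
    sims = F.cosine_similarity(
        query_features.unsqueeze(1), prototypes.unsqueeze(0), dim=-1
    )
    return beta * sims  # a softmax over these gives P(y_c | f_p)
```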

The Explain module explains the prediction $p$ and generates an explanation $R(f_p)$ of the classifier input $f_p$, which is used to compute the LRP weights $\omega_{lrp}$.

The LRP-weighted features $\omega_{lrp} \odot f_p$ are then fed to the classifier to produce the updated prediction $p_{lrp}$.

3.2 Training

Step 1: Obtain the prediction $p$ from the model's forward pass.

Step 2: Explain the classifier. Initialize the LRP relevance for each label, then use LRP to explain the classifier. As shown in the Explain block of the figure above, this yields the explanation $R(f_p)$ of the classifier input.

For FSC models that use a neural network as the classifier, the relevance of each label can be initialized with its logits. For metric-based models, however, the predicted values for all labels are positive, which would lead to similar explanations across labels.

Taking cosine similarity as an example, formula (1) first computes the probability of each category:

$$P(y_c \mid f_p) = \frac{e^{\beta \cdot cs_c(f_p)}}{\sum_{k=1}^{K} e^{\beta \cdot cs_k(f_p)}} \qquad (1)$$

where $cs_k(\cdot)$ is the cosine similarity between the query sample and category $k$, $f_p$ denotes the processed features fed to the classifier, and $\beta$ is a constant scaling parameter used to sharpen the maximum probability. Based on this probability, the relevance of category $c$ is defined by formula (2):

$$R_c = \log\left(\frac{P(y_c \mid f_p)}{1 - P(y_c \mid f_p)}\,(K - 1)\right) \qquad (2)$$

With this definition, $R_c > 0$ exactly when $P(y_c \mid f_p) > 1/K$, for $c = 1, \ldots, K$. In other words, a class label has positive relevance when its probability is greater than that of a random guess. $R_c$ is then backpropagated (BP) through the classifier to produce the relevance $R(f_p)$. The forward pass (FP) from layer $l$ to layer $l+1$ is expressed as:

$$z_j^{l+1} = f\big(y_j^{l+1}\big), \qquad y_j^{l+1} = \sum_i z_i^l\, w_{ij}$$

where $i$ and $j$ index the neurons of layers $l$ and $l+1$ respectively, and $f(\cdot)$ is the activation function. Let $R(\cdot)$ denote the relevance of a neuron and $R_{i \leftarrow j}$ the relevance contribution flowing $z_i^l \leftarrow z_j^{l+1}$. Two LRP BP rules are used, $LRP_\varepsilon$ and $LRP_\alpha$:

1) $LRP_\varepsilon$:

$$R_{i \leftarrow j} = \frac{z_i^l w_{ij}}{y_j^{l+1} + \varepsilon \cdot \operatorname{sign}(y_j^{l+1})}\, R(z_j^{l+1})$$

where $\varepsilon$ is a small positive number; the term $\varepsilon \cdot \operatorname{sign}(y_j^{l+1})$ keeps the denominator away from zero.

2) $LRP_\alpha$:

$$R_{i \leftarrow j} = \left(\alpha \cdot \frac{(z_i^l w_{ij})^+}{\sum_{i'} (z_{i'}^l w_{i'j})^+} - (\alpha - 1) \cdot \frac{(z_i^l w_{ij})^-}{\sum_{i'} (z_{i'}^l w_{i'j})^-}\right) R(z_j^{l+1})$$

where $\alpha \ge 1$ controls the ratio of positive relevance in the BP, and $(*)^+ = \max(*, 0)$, $(*)^- = \min(*, 0)$.

The relevance of $z_i^l$ is the sum of all relevance contributions flowing into it:

$$R(z_i^l) = \sum_j R_{i \leftarrow j}$$

To obtain $R(f_p)$, this paper uses $LRP_\varepsilon$ for linear layers and $LRP_\alpha$ for convolutional layers. $R(f_p)$ is normalized by its maximum absolute value.
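A minimal sketch of the two rules for a single linear layer, in NumPy (a generic LRP implementation under the assumption of no bias term; not the authors' code):

```python
import numpy as np

def lrp_epsilon(z: np.ndarray, w: np.ndarray, relevance: np.ndarray,
                eps: float = 1e-3) -> np.ndarray:
    """LRP-epsilon for a linear layer y = z @ w.

    z: (n_in,) input activations; w: (n_in, n_out) weights;
    relevance: (n_out,) relevance of the output neurons.
    Assumes the pre-activations y are nonzero.
    """
    y = z @ w                                    # y_j = sum_i z_i * w_ij
    denom = y + eps * np.sign(y)                 # stabilized denominator
    contrib = (z[:, None] * w) / denom[None, :]  # share of neuron i in output j
    return contrib @ relevance                   # R(z_i) = sum_j R_{i<-j}

def lrp_alpha(z: np.ndarray, w: np.ndarray, relevance: np.ndarray,
              alpha: float = 1.0, eps: float = 1e-9) -> np.ndarray:
    """LRP-alpha for a linear layer, with the negative share weighted by alpha - 1."""
    zw = z[:, None] * w                          # contributions z_i * w_ij
    pos, neg = np.maximum(zw, 0), np.minimum(zw, 0)
    pos_share = pos / (pos.sum(axis=0, keepdims=True) + eps)
    neg_share = neg / (neg.sum(axis=0, keepdims=True) - eps)
    return (alpha * pos_share - (alpha - 1) * neg_share) @ relevance
```

With $\alpha = 1$, as used in the experiments, only the positive contributions propagate relevance.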

Step 3: LRP-weighted features. To strengthen features that are highly relevant to the prediction and weaken those with low relevance, the LRP weights and LRP-weighted features are defined as:

$$\omega_{lrp} = 1 + R(f_p), \qquad \hat f_p = \omega_{lrp} \odot f_p$$

where $\odot$ is the element-wise product. Because $R(f_p) \in [-1, 1]$ after normalization, $\omega_{lrp}$ enlarges features with positive relevance and attenuates features with negative relevance.

Step 4: Finally, pass the LRP-weighted features to the classifier to generate the explanation-guided prediction $p_{lrp}$, and train with the total loss of formula (9):

$$L = \xi\, L_{ce}(y, p) + \lambda\, L_{ce}(y, p_{lrp}) \qquad (9)$$

where $L_{ce}$ is the cross-entropy loss, and $\xi, \lambda$ are positive coefficients that control how much of the information in $p$ and $p_{lrp}$ is used.
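Putting the four steps together, one training step might look like the following sketch (`encoder`, `classifier`, and `lrp_explain` are assumed, hypothetical components; only the loss combination follows formula (9)):

```python
import torch
import torch.nn.functional as F

def explanation_guided_step(encoder, classifier, lrp_explain,
                            support, query, labels,
                            xi: float = 1.0, lam: float = 1.0) -> torch.Tensor:
    """One explanation-guided training step; returns the total loss of formula (9)."""
    f_p = encoder(support, query)            # processed features fed to the classifier
    p = classifier(f_p)                      # Step 1: ordinary prediction (logits)

    with torch.no_grad():                    # Step 2: LRP explanation of the prediction;
        r = lrp_explain(classifier, f_p, p)  # relevance is treated as a constant weight
        r = r / r.abs().max()                # normalize R(f_p) to [-1, 1]

    w_lrp = 1.0 + r                          # Step 3: LRP weights
    p_lrp = classifier(w_lrp * f_p)          # Step 4: explanation-guided prediction

    # Formula (9): xi and lam control how much of p and p_lrp is used.
    return xi * F.cross_entropy(p, labels) + lam * F.cross_entropy(p_lrp, labels)
```

The coefficients would then be set per model and setting, e.g. $\xi = 1, \lambda = 0.5$ for RN and GNN in the 1-shot case (see Section 4.1).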

4. Experiment

The experiments are carried out on RelationNet (RN) and two more recent models: cross attention network (CAN) and GNN. The corresponding settings of these three models under the CD-FSC structure are as follows:

In addition, explanation-guided training is combined with LFT, and the resulting performance improvement demonstrates its compatibility with LFT.

4.1 Dataset and model preparation

Five datasets are used: miniImagenet, CUB, Cars, Places, and Plantae. miniImagenet serves as the training and validation set; the other four are used as test sets.

The image encoders of RN and CAN are ResNet10 and ResNet12, respectively. All three models are trained under the 5-way 5-shot and 5-way 1-shot settings. The LRP BP parameters used in all experiments are $\alpha = 1$, $\varepsilon = 0.001$.

By varying the values of $\xi, \lambda$ in the total loss of formula (9), the following is observed: for RN and GNN, the two models with parametric (trainable) classifiers, relying entirely on $L_{ce}(y, p_{lrp})$ (i.e., $\xi = 0$) makes the model difficult to converge and yields only a small gain, while CAN, which uses a non-parametric cosine-similarity classifier, is not affected. This is because explaining a poor classifier is of little use and can skew the classifier's parameters in the wrong direction from the start, especially in the few-shot regime. It is therefore necessary to combine $L_{ce}(y, p)$ to stabilize training, and to increase its weight in the 1-shot setting. For RN and GNN: 5-way 1-shot uses $\xi = 1, \lambda = 0.5$; 5-way 5-shot uses $\xi = 1, \lambda = 1$. For CAN: $\xi = 0, \lambda = 1$, with the cosine-similarity scaling of formula (1) set to $\beta = 7$.

At test time, 2000 randomly sampled episodes are evaluated, with 16 query images per episode.

4.2 Evaluation

For a more comprehensive analysis, transductive inference (transductive learning) is used: in the test phase, query images that have been classified with high confidence are added to the support set as extra support images. This is an iterative process; the experiment runs two iterations of transductive inference, using 35 such query images in the first iteration and 70 in the second. Because GNN requires a fixed number of support images, transductive inference is applied only to RN and CAN. A minimal sketch follows.
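As a rough sketch of this iterative procedure (the `classify` callable and the top-k confidence selection are illustrative assumptions):

```python
import torch

def transductive_augment(classify, support_x, support_y, query_x,
                         n_per_iter=(35, 70)):
    """Iteratively move confidently classified queries into the support set.

    classify(support_x, support_y, query_x) -> (n_query, K) class probabilities.
    n_per_iter: how many high-confidence queries to add at each iteration.
    """
    for n_add in n_per_iter:
        probs = classify(support_x, support_y, query_x)
        conf, pseudo_labels = probs.max(dim=1)
        top = conf.topk(min(n_add, len(conf))).indices  # most confident queries
        # Augment the support set with pseudo-labeled query images.
        support_x = torch.cat([support_x, query_x[top]])
        support_y = torch.cat([support_y, pseudo_labels[top]])
    return classify(support_x, support_y, query_x)
```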

The figure above shows the results for RN and CAN, where T denotes transductive inference. Adding more support images in this way does indeed improve performance.

The figure above shows the GNN results. miniImagenet is the training and validation set; the other four datasets are the test sets. The authors attribute the weaker performance on the other datasets, compared with the results on miniImagenet, to the following: FSC methods learn to remove information that is not discriminative, but information that is useless in one episode may be pivotal in other episodes.

Next, some personal opinions of mine.

First: the data in miniImagenet may be more strongly correlated with the data in CUB and Places, which would explain why, across different settings, performance on these two datasets is noticeably better than on the other two.

Second: compared with the original RN, CAN, and GNN, the LRP-based improvement is small or absent. Related commentary I have seen suggests this may be because what the authors actually achieve is a comparatively good feature representation, rather than a real solution to the cross-domain, few-sample problems faced by CD-FSC (as the authors themselves note in the introduction).

4.3 Combination of Explanation-guided training and LFT

The LFT model is trained with a pseudo-seen domain and pseudo-unseen domains. In this experiment, miniImagenet is the pseudo-seen domain, three of the other four datasets are the pseudo-unseen domains, and the remaining one is held out for testing. The pseudo-unseen domains are used to train the feature-wise transformation layers, and the pseudo-seen domain is used to update the other trainable parameters of the model. If the parameters of the feature-wise transformation layers are fixed, the result is FT: adding noise with a fixed distribution at chosen intermediate layers. A sketch of such a layer follows.
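For intuition, here is a minimal sketch of a feature-wise transformation layer in the spirit of LFT (the Gaussian scale-and-bias parameterization with softplus and the initial values are assumptions based on the general LFT recipe, not the authors' exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureWiseTransformation(nn.Module):
    """Add learned per-channel scale-and-bias noise to a feature map."""

    def __init__(self, n_channels: int):
        super().__init__()
        # Learnable hyperparameters of the per-channel noise distributions.
        self.theta_gamma = nn.Parameter(torch.full((1, n_channels, 1, 1), 0.3))
        self.theta_beta = nn.Parameter(torch.full((1, n_channels, 1, 1), 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W); noise is injected only during training.
        if not self.training:
            return x
        b, c = x.size(0), x.size(1)
        gamma = 1 + torch.randn(b, c, 1, 1, device=x.device) * F.softplus(self.theta_gamma)
        beta = torch.randn(b, c, 1, 1, device=x.device) * F.softplus(self.theta_beta)
        return gamma * x + beta
```

In the LFT setup, `theta_gamma` and `theta_beta` would be optimized on the pseudo-unseen domains while the rest of the model is updated on the pseudo-seen domain; freezing them yields the FT variant described above.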

As can be seen from the figure above, performance improves gradually as the various modules are added.

4.4 Analyzing the effect of explanation-guided training

This analysis stems from the information bottleneck framework: training a discriminative classifier amounts to learning to filter out irrelevant features. Removing information means that the channels associated with that information are not activated.

Traditional classification tasks classify a fixed set of categories, so removing irrelevant information does no harm. In FSC, the categories change between episodes; information that is useless in one episode may be critical in another, which causes performance on the test set to degrade relative to the validation set.

If the classifier overfits and frequently predicts wrong class labels, explanation-guided training will identify the features relevant to the wrongly predicted classes and reinforce them, so that the subsequent loss penalizes these strengthened features more heavily. This prevents the intermediate features from becoming biased toward particular classes and leads to better generalization.

The experimental results show that explanation-guided training avoids excessive information removal, thereby avoiding overfitting to the source domain.

4.5 Qualitative analysis of LRP

In this section, the LRP explanations of the input images are visualized as heatmaps, from which it is easy to observe which parts of an image are used for the prediction.

The first row of the figure above shows the support images; for each query image, both the attention heatmap and the LRP heatmap are shown. For the correctly classified Q1 and Q3, the LRP heatmaps of the correct labels highlight the relevant features; in particular, they capture the window features of the bus and the head features of the malamute.

While the LRP heatmaps for the other, incorrect labels show more negative evidence, similarities between the query image and the explained label can still be found. For example, when the label malamute is explained for Q3, the LRP heatmap highlights the textures within circular structures.

5. Reference

Sun J, Lapuschkin S, Samek W, et al. Explanation-guided training for cross-domain few-shot classification[C]//2020 25th International Conference on Pattern Recognition (ICPR). IEEE, 2021: 7609-7616.
