Mécanisme d'attention dans les réseaux de neurones profonds

Le mécanisme dit d'attention est un mécanisme qui se concentre sur des informations locales, telles qu'une certaine zone d'image dans une image. Les régions attentionnelles ont tendance à changer à mesure que les tâches changent. Le mécanisme d'attention est une technique couramment utilisée dans l'apprentissage en profondeur. Lorsque nous utilisons un réseau de neurones convolutifs pour traiter des images, nous espérons que le réseau de neurones convolutifs pourra prêter attention à des endroits relativement importants. Le mécanisme d'attention est un moyen de réaliser l'attention adaptative du réseau.

Les mécanismes d'attention peuvent généralement être divisés en attention de canal et attention spatiale, ou une combinaison des deux.

1. Canal Attention (SENet)

L'attention du canal se soucie du poids de chaque canal de fonctionnalité, ce qui permet au réseau de se concentrer sur le canal auquel il doit le plus prêter attention. La figure ci-dessus montre ses étapes de mise en œuvre spécifiques, et il y a 6 étapes comme suit :

1) Effectuer une mise en commun globale (maximum/moyenne) sur la carte des caractéristiques d'entrée de HxWxC, et une bande de caractéristiques 1x1xC peut être obtenue, et la longueur est le nombre de canaux C de la carte des caractéristiques d'entrée ;

2) Effectuez ensuite une connexion complète, réduisez la dimension à la dimension C/r avec moins de neurones et obtenez un vecteur 1x1XC/r ;

3) Fonction d'activation ReLu ;

4) La deuxième connexion complète restaure la longueur du vecteur à la même valeur qu'auparavant, c'est-à-dire le vecteur de 1x1xC ;

5) La fonction sigmoïde normalise chaque poids entre 0 et 1 ;

6) Chaque canal de la carte de caractéristiques d'entrée est multiplié par le poids de chaque canal pour obtenir une nouvelle carte de caractéristiques.

Le code d'implémentation est le suivant (la mise en commun maximale et moyenne est utilisée dans le code, et finalement ajoutée) :

 

import torch
import torch.nn as nn
import torch.utils.data as Data


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x): # x 的输入格式是:[B, C, H, W]
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

2. Attention spatiale

L'attention spatiale se soucie du poids de chaque pièce sur la surface.Les étapes de mise en œuvre spécifiques sont les suivantes :

1) Pour chaque position de la carte de caractéristiques d'entrée HxWxC, prenez la valeur maximale de tous les canaux, et une carte de caractéristiques HxWx1 peut être obtenue ;

2) Pour chaque position de la carte des caractéristiques d'entrée de HxWxC, prenez la valeur moyenne de tous les canaux et obtenez également une carte des caractéristiques de HxWx1 ;

3) Empiler (concaténer) les deux cartes de caractéristiques pour obtenir la carte de caractéristiques de HxWx2, effectuer une convolution pour ajuster le nombre de canaux à 1 et obtenir la carte de caractéristiques de HxWx1 ;

4) La fonction sigmoïde normalise le poids de chaque position entre 0 et 1 ;

5) Chaque position de la carte d'entités en entrée est multipliée par le poids de chaque position pour obtenir une nouvelle carte d'entités.

Le code d'implémentation est le suivant :

import torch
import torch.nn as nn
import torch.utils.data as Data

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x): # x 的输入格式是:[B, C, H, W]
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

 

3. Réalisation de CBAM

CBAM combine le mécanisme d'attention de canal et le mécanisme d'attention spatiale, qui peut obtenir de meilleurs résultats que le mécanisme d'attention de SENet qui se concentre uniquement sur les canaux. Le diagramme schématique de sa mise en œuvre est présenté ci-dessous.CBAM traitera le mécanisme d'attention du canal et le mécanisme d'attention spatiale pour la couche d'entités d'entrée.

Disposition des modules d'attention : étant donné une image d'entrée, deux modules d'attention, canal et espace, calculent l'attention complémentaire, en se concentrant sur "quoi" et "où", respectivement. Dans cette optique, les deux modules peuvent être placés de manière parallèle ou séquentielle. Les auteurs de l'article original ont découvert que les permutations séquentielles donnaient de meilleurs résultats que les permutations parallèles.

Le code d'implémentation est le suivant :

 

class cbam_block(nn.Module):
    def __init__(self, channel, ratio=8, kernel_size=7):
        super(cbam_block, self).__init__()
        self.channelattention = ChannelAttention(channel, ratio=ratio)
        self.spatialattention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channelattention(x)
        x = x * self.spatialattention(x)
        return x

Je suppose que tu aimes

Origine blog.csdn.net/wanchengkai/article/details/128716244
conseillé
Classement