CVPR 2021 attention mechanism: Coordinate Attention (CA) - source code

In one sentence, CA attention builds on channel attention by also encoding positional information, in effect combining channel attention with spatial attention. By contrast, the SE module only considers channel attention, while CBAM computes channel attention and spatial attention as two separate steps.

Paper link

Source link


(Figure: structural comparison of the SE module, the CBAM module, and the CA module)
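For contrast with the CA code below, here is a minimal sketch of SE-style channel attention. This block is not from the original post; the class name SEBlock and the reduction value of 16 are illustrative assumptions.

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    # Squeeze-and-Excitation (illustrative sketch): global average pooling followed by
    # two fully connected layers that produce per-channel weights. No positional
    # information is used, which is the gap CA addresses.
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (b,c,h,w) --> (b,c,1,1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w  # reweight channels only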

import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    # hard sigmoid: ReLU6(x + 3) / 6
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    # hard swish: x * h_sigmoid(x)
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CA(nn.Module):
    def __init__(self, inp, reduction):
        super(CA, self).__init__()
        # h: height (rows), w: width (columns)
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # (b,c,h,w) --> (b,c,h,1)
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # (b,c,h,w) --> (b,c,1,w)

        # mip = max(8, inp // reduction)  # used by the paper's authors
        mip = inp // reduction  # used by the blog author, with reduction = int(math.sqrt(inp))

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)                      # (b,c,h,1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (b,c,w,1)

        # concatenate along the spatial axis, then shared 1x1 conv + BN + h_swish
        y = torch.cat([x_h, x_w], dim=2)          # (b,c,h+w,1)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        # split back into the two directional components
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        # per-direction attention weights
        a_h = self.conv_h(x_h).sigmoid()          # (b,c,h,1)
        a_w = self.conv_w(x_w).sigmoid()          # (b,c,1,w)

        out = identity * a_w * a_h

        return out
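A short usage sketch of the CA module above; the input shape is illustrative, and the reduction value follows the blog author's choice of int(math.sqrt(inp)):

import math
import torch

inp = 64
ca = CA(inp, reduction=int(math.sqrt(inp)))  # blog author's choice; the paper's code uses a fixed reduction instead
x = torch.randn(2, inp, 32, 32)              # (b, c, h, w)
out = ca(x)
print(out.shape)                             # torch.Size([2, 64, 32, 32]), same shape as the input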

Origin blog.csdn.net/Ratib/article/details/120867413