[论文阅读笔记11]Swin-Transformer

0. 前言

Swin Transformer, 即Shift Window Transformer, 它旨在让Transformer结构跟CNN一样, 也可以作为骨干网络在各种计算机视觉任务中来使用, 以及解决ViT计算复杂度高的问题.

具体地, 在CNN网络中, 大多都是层级结构. 比如说, 每一层都让高宽减半, 通道数增加. 这样在每一层都具有不同的感受野, 进而获得不同尺度的语义信息. 那么Transformer结构可不可以也这么做呢? Swin-Transformer就是这样一个工作.

1. Swin-Transformer工作流程

Swin-T工作流程图如下:(论文figure3(a) (b))
在这里插入图片描述
根据前向过程捋一捋. 从左到右:

1.1 输入, 以及Patch Partition + Linear Embedding

在这里插入图片描述

假设图像大小为 ( 224 , 224 , 3 ) (224,224,3) (224,224,3), 以 4 × 4 4\times4 4×4像素大小为一个patch, 这样图像的高宽就被切成了 56 , 56 56,56 56,56, 把每个patch的通道维合起来, 通道维就变成了 4 × 4 × 3 = 48 4\times4\times3=48 4×4×3=48, 因此得到的张量维度为 ( 56 , 56 , 48 ) (56,56,48) (56,56,48).
随后经过Linear Embedding, 高宽不变, 将维度变为特定值. 对于Swin-T这个小模型来说, 这个值为96. 因此之后的维度变为 ( 56 , 56 , 96 ) (56,56,96) (56,56,96).
这一步的大概思想跟ViT是一致的.

代码是这样实现的:

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        # 如果大小不整除就padding
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        # 通过卷积操作来完成维度转换 卷积输入维度为3 输出维度为96(Swin-T), 
        # kernel size是4, stride size也是4. 这样输出维度变为(bs, 96, H/4, W/4)
        x = self.proj(x)  # B C Wh Ww  
        # Layer norm操作
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x

注意: 在代码中Patch Partition + Linear Embedding是一次实现的, 通过卷积直接将输入图像维度变为 ( b s , 96 , 56 , 56 ) (bs, 96, 56, 56) (bs,96,56,56). 随后将后两个维度展平:

		Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)  # bs, 56*56, 96
        x = self.pos_drop(x)

1.2 Stage1: 2*Block

在这里插入图片描述

在第一步得到 ( 56 , 56 , 96 ) (56,56,96) (56,56,96)维度张量并拉直为 ( 96 , 56 ∗ 56 = 3136 ) (96, 56*56=3136) (96,5656=3136)后, 经过block后, 维度不变, 仍为 ( 56 , 56 , 96 ) (56,56,96) (56,56,96).

block中具体内容见后文.

1.3 stage2, 3, 4: Patch Emerging + n*blocks

在这里插入图片描述

对于stage2, 3, 4, 流程和stage1差不多, 只不过经过blocks的数目不一样. 由此看出, Swin-Transformer的框架和CNN真的很像, 这就是它层级式的结构, 和卷积层异曲同工.

Patch Merging做的操作类似于CNN中的池化, 其将高宽减半, 维度增加. 而后经过block后维度又不变, 因此一个stage就相当于一个kernel size=3 padding=1 stride=1的卷积层 + stride=2的池化层.

Patch Merging具体的做法类似于空洞卷积, 也就是将张量的高宽分为 4 × 4 = 16 4\times4=16 4×4=16个部分, 每隔一个部分凑一起, 最后将维度拼接. 由于是构成了4个部分, 设原来的维度是 c c c, 则拼接后维度为 4 c 4c 4c. 为了和卷积层维度加倍保持一致, 再用一个线性层将 4 c 4c 4c映射到 2 c 2c 2c.

Patch Merging过程如下图所示. 图中1, 2, 3, 4表示相应的部分.
在这里插入图片描述

代码:

class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C 在H, W维隔一个取一个, 即图中四个1部分
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C 图中四个2部分
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C 图中四个3部分
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C 图中四个4部分
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)  # 线性层 将4c降为2c

        return x

2. block: 窗口自注意力和移动窗口自注意力

我们拿stage1举例子. 在stage1的block之前, 我们得到了维度为 56 , 56 , 96 56,56,96 56,56,96的张量. 如果想降低计算复杂度, 就想缩小计算自注意力的范围, 而不是像ViT那样整张图都算注意力. 为此, 提出另一个计算单元: window(注意, 不是patch).

例如, 我们把 56 , 56 , 96 56,56,96 56,56,96的张量在H,W维分成 7 × 7 7\times7 7×7的window, 则每个window的size是 7 × 7 × 96 7\times7\times96 7×7×96. 整个张量就分成了 8 × 8 = 56 8\times8=56 8×8=56个window. 随后在每个window内计算自注意力, 这样就降低了复杂度.

但是, window和window之间不计算注意力, 不就割裂开了吗? 这样的割裂也违反了Transformer设计的初衷. 因此, 提出了滑动窗口注意力机制. 直观上讲, 就是将window"变形"一下, 相当于拿一个移动的模板:
在这里插入图片描述
这样一移动, window和window之间就有联系了. 如上图所示, 为了举例子方便, 假设有4个window. 在方方正正的window算完自注意力之后, 将窗口移动一下, 向右边的图所示. 但是移动之后就变成了9个window, 这种数目的变化会造成计算效率的降低. 那么如何将移动后的9个window也变得和4个window一样计算呢?

做法是掩码机制. 直觉上说, 移动后的window仍然按照移动前的计算, 只不过加入掩码后, 可以将原本的计算等效成右图. 具体细节待补充.

因此, 整个stage的block部分过程包括一个窗口自注意力的block和滑动窗口自注意力的block. 如此, 方能在降低计算复杂度的同时实现窗口的联系. 如下图所示:
在这里插入图片描述
因此, 每个stage的block数目都是偶数.

代码:
将张量分成windows的代码:

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape  # shape为bs, h, w, c
    # 将高宽除以window_size(例如7), 将x展成(bs, h/7, 7, w/7, 7, c)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)

    # 首先将维度换位置(permute方法), 变成(bs, h/7, w/7, 7, 7, c)
    # 然后用contiguous方法拷贝了一份张量在内存中的地址 然后将地址按照形状改变后的张量的语义进行排列
    # view改变维度, 为(bs*h/7*w/7, 7, 7, c)
    # bs*h/4*w/4就是后面要用的窗口的个数(nW变量, number of windows)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

一个基本block块的前向传播代码:

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W  # 高 宽
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)  # Layer norm层
        x = x.view(B, H, W, C)  # bs, h, w, c

        # 做padding
        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        # 将张量分成windows 维度为windows数目*bs, 7, 7, c
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        # 将window reshape成二维矩阵以便输入到transformer
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        # 自注意力 窗口自注意力和滑动窗口自注意力交叉进行
        # mask用以实现并行计算
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        # 恢复成windows数目*bs, 7, 7, c维度
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        # 将计算结果reshape成分成窗口之前的维度形式 即
        # bs, H, W, C
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        # 前向传播
        x = shortcut + self.drop_path(x)  # 残差连接 输入加窗口计算结果
        x = x + self.drop_path(self.mlp(self.norm2(x)))  # 残差连接 加layer norm后经过MLP的结果

        return x

其中用于转换维度的window_reverse函数:

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))  # bs = 第一个维度 / 窗口个数
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)  # bs, h/7, w/7, 7, 7, c
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)  # bs, h, w, c
    return x

3.复杂度分析

比较一下窗口自注意力和ViT这种全局自注意力复杂度的差别.
对于一次, 一个head的自注意力来说, 假设输入为 x ∈ R h w × c x\in\mathbb R^{hw\times c} xRhw×c, x x x要和三个矩阵相乘得到 q , k , v q,k,v q,k,v. 也即:
q = x W q , W q ∈ R c × c k = x W k , W k ∈ R c × c v = x W v , W v ∈ R c × c q=xW_q, W_q\in \mathbb R^{c\times c}\\ k=xW_k, W_k\in \mathbb R^{c\times c}\\ v=xW_v, W_v\in \mathbb R^{c\times c} q=xWq,WqRc×ck=xWk,WkRc×cv=xWv,WvRc×c
因此计算 q , k , v q,k,v q,k,v需要 3 × ( h w ) × c 2 3\times(hw)\times c^2 3×(hw)×c2次乘法(相乘后有 h w × c hw\times c hw×c个元素, 每个元素需要经过c次乘法)

接下来计算自注意力矩阵
A = q k T ∈ R h w × h w A=qk^T\in\mathbb R^{hw\times hw} A=qkTRhw×hw
A A A的计算需要 ( h w ) 2 c (hw)^2c (hw)2c次乘法.

之后将 A , v A, v A,v相乘计算注意力分数, 需要 h w × c × h w = ( h w ) 2 c hw\times c \times hw=(hw)^2c hw×c×hw=(hw)2c次乘法.

A , v A, v A,v相乘后得到维度 h w × c hw\times c hw×c的向量, 需要经过一个线性层. 线性层相当于矩阵乘法, 维度从 h w × c → h w × c hw\times c \to hw\times c hw×chw×c, 需要 ( h w ) × c 2 (hw)\times c^2 (hw)×c2次乘法.

综上, 总共需要的乘法次数(等价于时间复杂度)为 3 × ( h w ) × c 2 + ( h w ) 2 c + ( h w ) 2 c + ( h w ) × c 2 = 4 h w c 2 + 2 ( h w ) 2 c = O ( ( h w ) 2 ) 3\times(hw)\times c^2+(hw)^2c+(hw)^2c+(hw)\times c^2=4hwc^2+2(hw)^2c=O((hw)^2) 3×(hw)×c2+(hw)2c+(hw)2c+(hw)×c2=4hwc2+2(hw)2c=O((hw)2).

对于(滑动)窗口自注意力, 由于每次只在一个窗口内作自注意力, 因此上面公式中的 h , w h,w h,w可先替换成窗口的 h 0 , w 0 h_0,w_0 h0,w0. 一共有 h w / h 0 w 0 hw/h_0w_0 hw/h0w0个窗口, 一次(滑动)窗口自注意力的计算复杂度为:

h w h 0 w 0 [ 4 h 0 w 0 c 2 + 2 ( h 0 w 0 ) 2 c ] = 4 ( h w ) c 2 + 2 h 0 w 0 ( h w ) c = O ( h w ) \frac{hw}{h_0w_0}[4h_0w_0c^2+2(h_0w_0)^2c]\\ =4(hw)c^2+2h_0w_0(hw)c=O(hw) h0w0hw[4h0w0c2+2(h0w0)2c]=4(hw)c2+2h0w0(hw)c=O(hw)

所以,(滑动)窗口自注意力的复杂度是关于张量维度的线性复杂度, 而ViT中的全局自注意力是二次复杂度, 根据论文所说, Swin-T的复杂度与ResNet50相当.

猜你喜欢

转载自blog.csdn.net/wjpwjpwjp0831/article/details/123974893