Interpretation of Swin_Transformer source code


Foreword

 This article records the parts of swin_transformer that confused me: relative position encoding and shifted-window self-attention. Many thanks for the analysis here: Zhihu link . This post only walks through the source code on top of that Zhihu article, so it is recommended that readers go through the linked analysis first.


1. The overall structure of the model

[Figure: the overall architecture of the Swin Transformer]
 The structural benchmark of swin_transformer is ResNet: every time the feature map passes through a stage, its spatial size is halved and the number of channels is doubled. Briefly, the network flow is as follows: assuming the input image size is 224*224*3, it is first converted into a 56*56*96 feature map by the Patch Partition and Linear Embedding modules (C=96 in the figure above); after each stage, a patch merging module halves the feature map size and doubles the number of channels, finally producing a 7*7*768 feature map after 32x downsampling.
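 To make the shape flow concrete, here is a small sanity-check sketch of my own (assuming the Swin-T configuration from the figure: patch_size=4, C=96) that prints each stage's resolution and channel count:

H = W = 224 // 4          # 56 after Patch Partition + Linear Embedding
C = 96
for stage in range(1, 5):
    print(f"stage {stage}: {H}x{W}x{C}")
    if stage < 4:         # Patch Merging between stages: H and W halved, C doubled
        H, W, C = H // 2, W // 2, C * 2
# stage 1: 56x56x96
# stage 2: 28x28x192
# stage 3: 14x14x384
# stage 4: 7x7x768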
 Next, I will analyze the source code of the more important modules according to the above process.

2. Patch Partition + Linear Embedding

 Assuming the input feature map is (224,224,3), the Patch Partition operation splits it into 4*4*3 patches, i.e. (56,56,4,4,3), and then flattens the last three dimensions to obtain (56,56,48); Linear Embedding then maps the last dimension from 48 to 96.
 In actual code, the whole process above can be implemented directly with a single convolution that uses 96 kernels of size 4 and stride 4:

import torch
import torch.nn as nn

patch_size = 4
# 96 output channels, kernel size 4, stride 4: one output position per 4*4 patch
proj = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=patch_size, stride=patch_size)
# forward
x = torch.randn(1, 3, 224, 224)
x = proj(x).flatten(2).transpose(1, 2)  # [1,96,56,56] --> [1,96,3136] --> [1,3136,96]

 Through the above operations, 3136 tokens are obtained, each of which is a 96-dimensional vector.

3. Patch Merging

 Continuing along the flow chart: after obtaining the [1,3136,96] features, they pass through several stages. The blocks inside each stage are discussed later; for now, just know that the input and output sizes are the same after a block. For the convenience of the following diagrams, assume that after stage 2 we have [1,28*28,192] features and now need to enter the third stage, which first requires a Patch Merging operation: halve the feature map size while doubling the number of channels. Here is how the code implements this:

import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution  # resolution of the input feature map
        self.dim = dim                            # number of input channels
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # linear layer that maps 4*dim channels to 2*dim
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)
        # non-overlapping even/odd slicing along both height and width
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # after slicing, H and W are halved and the number of channels grows 4x
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)  # the linear layer compresses the channels from 4*C to 2*C

        return x

if __name__ == '__main__':

    x = torch.randn(1, 784, 192)
    patch_merge = PatchMerging(input_resolution=(28, 28), dim=192)
    x = patch_merge(x)  # [1,196,384]: the number of tokens is reduced 4x and the channels are doubled

 The code above is commented in detail; here is a schematic sketch to illustrate:
[Figure: schematic of the Patch Merging even/odd slicing and concatenation]
 Through the above operations, a downsampling effect similar to pooling in a CNN is achieved (here with a learnable linear projection rather than max pooling).

4. Window attention

4.1 Divide windows

 After patch merging, a [1,14*14,384] feature map is obtained. A major advantage of the Swin Transformer is that it overcomes the high computational cost of global self-attention; the means of doing so is to divide the feature map into small windows and compute self-attention only among the tokens inside each window, so that (for a fixed window size) the cost grows linearly with the number of tokens rather than quadratically. So first, look at how the feature map is divided into windows of size 7*7.

def window_partition(x, window_size):
    '''
    input:
        x: [b,h,w,c]
        window_size: 7 in the paper
    return: windows: [num_windows * B, window_size, window_size, c]
    '''
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5)  # [b, h//window_size, w//window_size, window_size, window_size, c]
    windows = windows.contiguous().view(-1, window_size, window_size, c)
    return windows
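
 For completeness, the Swin repo also provides the inverse operation, window_reverse, which stitches the windows back into a (B, H, W, C) feature map after attention. It is roughly the partition above run backwards (reproduced here with my own comments):

def window_reverse(windows, window_size, H, W):
    '''
    input:
        windows: [num_windows * B, window_size, window_size, c]
    return: x: [B, H, W, c]
    '''
    B = int(windows.shape[0] / (H * W / window_size / window_size))  # recover the batch size
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x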

 To see the resulting windows more intuitively, here is a simple demo: assume the feature map size is 14*14 and window_size = 7; the print looks as follows:

import torch

# create a 14*14 feature map, which can be divided into four 7*7 windows
# for easier display, fill each window with the values 0, 1, 2, 3 respectively
x = torch.zeros((14, 14))
x[0:7, 0:7] = 0
x[0:7, 7:]  = 1
x[7:, 0:7] = 2
x[7:, 7:]  = 3
# partition into windows
x = x.view(14 // 7, 7, 14 // 7, 7)
x = x.permute(0, 2, 1, 3).contiguous().view(-1, 7, 7)
print(x)

[Figure: the printed result, a [4,7,7] tensor whose four windows are filled with 0, 1, 2 and 3]

4.2. Relative position encoding

  Now the [1,14*14,384] feature map has been divided into four 7*7*384 windows. Before computing self-attention inside each window, following the transformer idea, the positions inside each window need to be encoded, and the Swin Transformer uses relative position encoding. Briefly, the process is as follows: for a 7*7 window, the relative position of the top-left corner (0,0) with respect to the bottom-right corner (6,6) is (-6,-6), and that of (6,6) with respect to (0,0) is (6,6). So along each axis the relative coordinate inside a window lies in [-6, 6], 13 possible values in total; for 2D positions, 13*13 entries are needed. Expressed as a formula: if window_size = M, each relative coordinate lies in [-M+1, M-1], i.e. 2M-1 values per axis, so the table needs (2M-1)*(2M-1) entries. Now look at how the source code creates this table:

# wrap a ((2M-1)*(2M-1), num_heads) tensor in nn.Parameter: one learnable bias per
# relative offset and per attention head
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 
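
 A quick shape check of my own (num_heads=3 is an arbitrary value for the demo):

import torch
import torch.nn as nn

window_size = (7, 7)
num_heads = 3
table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
print(table.shape)  # torch.Size([169, 3]): 13*13 possible 2D relative offsets, one bias per head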

 After creating the relative position bias table, we need to compute, for every pair of tokens inside a window, the index used to look up that table. Here is the code first:

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

[Figure: illustration accompanying the relative position index code]
 Confused? That's fine; let's also go through the whole relative-position-index computation with a diagram. For ease of visualization, only a window of size 2*2 is used, i.e. window_size=(2,2) as in the demo above. The figure below walks through the execution of the code.
[Figure: step-by-step execution of the relative position index code for a 2*2 window]
 In actual use, the bias is simply gathered from the table using the relative position index computed above:

relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH

4.3. window_attn

 The code of window attention is essentially the same as plain self-attention, just with the extra window dimension folded into the batch. Look at the code: the input is the feature map x after passing through the window_partition function. Ignore the mask for now.

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_  # the original repo imports trunc_normal_ from timm

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # scaling factor for the dot-product attention
        self.scale = qk_scale or head_dim ** -0.5

        # relative position bias table and index: exactly the code explained in Section 4.2
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        # a single linear layer that produces q, k and v together
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # initialize the bias table
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape 
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        # look up the relative position bias and add it to attn
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            # softmax over the key dimension
            attn = self.softmax(attn)
        # dropout on the attention weights
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
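
 A small usage sketch of my own (not from the repo) to confirm the shapes, using the running example of four 7*7 windows with 384 channels:

attn_layer = WindowAttention(dim=384, window_size=(7, 7), num_heads=12)
x = torch.randn(4, 49, 384)   # (num_windows*B, N, C) as produced by window_partition
out = attn_layer(x)           # no mask: plain W-MSA
print(out.shape)              # torch.Size([4, 49, 384]) - same shape as the input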

5. Shifted window attention

 Window attention only exchanges information within each window; there is no information interaction across windows, so the model loses the kind of cross-region mixing that sliding convolutions give a CNN. The Swin Transformer therefore adds the shifted window attention module. As shown in the figure below, with the regular partition only four attention computations are needed, but after shifting the window grid the feature map is cut into 9 windows of different sizes, which would increase the amount of computation. So an efficient batched computation is designed that keeps the computation unchanged, and the tool used for this is the mask:
[Figure: regular window partition (4 windows) vs. shifted partition (9 windows of different sizes), and the cyclic-shift + mask solution]
 Next, how this is implemented. For ease of visualization, assume the feature map is 4*4 and window_size=(2,2). The code has two main parts: one handles the feature map, the other builds the mask. As shown in the figure below, a cyclic shift (torch.roll) is first applied to the feature map along both spatial dimensions:

import torch

x = torch.tensor([[[[1],[2],[3],[4]],[[5],[6],[7],[8]],[[9],[10],[11],[12]],[[13],[14],[15],[16]]]])

x = x.squeeze(-1)   # [1,4,4]
print(x)
shift_size = 1      # the shift distance is half the window size, i.e. 2 // 2 = 1
shifted_x = torch.roll(
    x,
    shifts=(-shift_size, -shift_size),
    dims=(1, 2))
print(shifted_x)
# tensor([[[ 6,  7,  8,  5],
#          [10, 11, 12,  9],
#          [14, 15, 16, 13],
#          [ 2,  3,  4,  1]]])

[Figure: the 4*4 feature map before and after the cyclic shift (torch.roll)]
 Then look at the mask generation process.

import torch

Hp = Wp = 4
img_mask = torch.zeros((1, Hp, Wp, 1))   # an all-zero mask the same size as the feature map

window_size = 2
shift_size = 2 // 2
# slices marking the regions created by the shift
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt   # every region gets its own label 0..8
        cnt += 1
# [1,4,4,1] --> [4,2,2,1]
# num_windows*b, 2, 2, 1: four windows in total, each of shape [2,2,1]
mask_windows = window_partition(img_mask, window_size)  # window_partition is the function from Section 4.1
# [4,2,2,1] --> [4,4]
mask_windows = mask_windows.view(-1, window_size * window_size)
'''
flattened:
tensor([[0., 0., 0., 0.],
        [1., 2., 1., 2.],
        [3., 3., 6., 6.],
        [4., 5., 7., 8.]])
'''
# [4,1,4] - [4,4,1] --> broadcast [4,4,4]: pairwise label differences inside each window
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
# in the original SwinTransformerBlock this is stored as a buffer:
# self.register_buffer("attn_mask", attn_mask)

 Through the above code, the four masks shown in the figure below are obtained. The internal logic, which confused me at first, boils down to this: tokens that come from the same original region carry the same label, so their pairwise difference is 0 and they may attend to each other; tokens from different regions have a non-zero difference, which is replaced by -100 so that their attention weight becomes (almost) 0 after the softmax.
[Figure: the four attention masks produced by the code above]
  With the feature map and the mask in hand, they are sent to the WindowAttention class to compute attention. Before that, first look at the attention matrices of the four windows themselves.
[Figure: the attention matrices of the four shifted windows]
 Once we have the four attention matrices, each can be added to its corresponding mask, so that the entries where mask==0 are kept and the entries where mask==-100 are suppressed. Take attn2 and mask2 as an example:

import torch

x2 = torch.tensor([[8], [5], [12], [9]])  # [4,1]: the four tokens of the second shifted window
attn2 = x2 @ x2.transpose(-1, -2)
'''
tensor([[ 64,  40,  96,  72],
        [ 40,  25,  60,  45],
        [ 96,  60, 144, 108],
        [ 72,  45, 108,  81]])
'''

In the original feature map, 8 and 12 come from the right edge while 5 and 9 come from the left edge; the cyclic shift puts them into the same window even though they are far apart in the image. So 8/12 should not compute self-attention with 5/9: for example, the sky and the ground are semantically far apart and need no information exchange. mask2 is therefore used to block attention between these pairs, while 8 with 12 (and 5 with 9) may still attend to each other. After the subsequent softmax, the positions that received -100 become essentially 0.
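 Continuing the toy example (attn2 from the snippet above), adding the mask and applying softmax shows the effect directly; mask2 is attn_mask[1] from the mask demo, written out here so the values are visible:

mask2 = torch.tensor([[   0., -100.,    0., -100.],
                      [-100.,    0., -100.,    0.],
                      [   0., -100.,    0., -100.],
                      [-100.,    0., -100.,    0.]])
print(torch.softmax(attn2 + mask2, dim=-1))
# each row now assigns (almost) all of its weight to tokens from its own region only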
 Finally, after the attention is computed and the windows are merged back into a feature map (window_reverse), it is rolled back to the original layout:

if self.shift_size > 0:
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

6. Network

 Finally, the network stacks the modules above to obtain the Swin Transformer structure.
[Figure: the overall Swin Transformer network assembled from the modules above]
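
 As a rough summary of my own (paraphrasing the repo's structure, not its exact code), the forward pass of Swin-T looks like this:

# x = PatchEmbed(img)                     # (B, 56*56, 96) for a 224x224 input
# for each of the 4 stages:               # Swin-T depths are [2, 2, 6, 2]
#     for each block in the stage:
#         x = SwinTransformerBlock(x)     # alternating W-MSA and SW-MSA blocks
#     x = PatchMerging(x)                 # after every stage except the last
# x = LayerNorm(x)
# x = global_average_pool(x)              # over the final 7*7 tokens
# logits = Linear(x)                      # classification head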

Summary

 An interpretation of Swin V2 will be released later. If you have any questions, add me on WeChat (+vx): wulele2541612007 and I will pull you into the group discussion.


Origin blog.csdn.net/wulele2/article/details/124966785