引言:从ViT的困境到Swin的突破
Vision Transformer(ViT)虽在图像分类中表现出色,但其全局注意力机制导致计算复杂度与图像尺寸呈平方关系(O(n²)),难以处理高分辨率图像。2021年,Swin Transformer通过层级架构+滑动窗口的创新设计,首次让Transformer成为目标检测、分割等密集预测任务的通用骨干网络。本文将从数学推导、结构设计和代码实现三方面揭示其核心原理。
一、Swin Transformer核心创新
1.1 层级特征金字塔
通过四个阶段逐步下采样,输出多尺度特征图(如224×224→7×7),完美适配检测/分割任务:
- Stage1:将图像划分为4×4的Patch(每个Patch视为一个"词")
- Stage2~4:通过Patch Merging进行2倍下采样
1.2 滑动窗口注意力(SW-MSA)
将特征图划分为不重叠的局部窗口(如7×7),在窗口内计算自注意力:
\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d}} + B)V
其中B
为相对位置编码。计算复杂度从O(n²)降至O(n)(n为图像像素数)。
二、滑动窗口的巧妙设计
2.1 窗口划分与循环移位
- 常规窗口划分:将特征图均匀切分为M×M窗口
- 移位窗口划分:将特征图循环移位(M/2, M/2)后重新划分,实现跨窗口信息交互
2.2 掩码机制
在移位窗口计算时,通过掩码矩阵屏蔽不同区域间的非法连接:
# 生成掩码矩阵
mask = torch.zeros((H, W))
mask[:M//2, :M//2] = 0 # 区域A
mask[M//2:, M//2:] = 1 # 区域B
masked_attention = attention + mask * -1e9
三、核心模块代码实现
3.1 窗口注意力模块
import torch
from torch import nn
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
# 相对位置编码表
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2*window_size-1)**2, num_heads))
# 初始化坐标索引
coords = torch.stack(torch.meshgrid(
torch.arange(window_size),
torch.arange(window_size)), dim=0)
coords_flatten = coords.flatten(1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1,2,0).contiguous()
relative_coords[:, :, 0] += window_size - 1
relative_coords[:, :, 1] += window_size - 1
relative_coords[:, :, 0] *= 2 * window_size - 1
relative_index = relative_coords.sum(-1)
self.register_buffer("relative_index", relative_index)
def forward(self, x):
B, H, W, C = x.shape
x = x.view(B, H//self.window_size, self.window_size,
W//self.window_size, self.window_size, C)
x = x.permute(0,1,3,2,4,5).contiguous().view(-1, self.window_size*self.window_size, C)
# 计算注意力
qkv = self.qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: t.view(-1, self.num_heads, C//self.num_heads), qkv)
attn = (q @ k.transpose(-2,-1)) * self.scale
# 添加相对位置编码
relative_bias = self.relative_position_bias_table[self.relative_index.view(-1)]
relative_bias = relative_bias.view(self.window_size**2, self.window_size**2, -1)
attn += relative_bias.permute(2,0,1).unsqueeze(0)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1,2).reshape(-1, C)
return x
3.2 Patch Merging层
class PatchMerging(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = nn.LayerNorm(4*dim)
self.reduction = nn.Linear(4*dim, 2*dim)
def forward(self, x):
B, H, W, C = x.shape
x = x.view(B, H//2, 2, W//2, 2, C)
x = x.permute(0,1,3,2,4,5).contiguous()
x = x.view(B, -1, 4*C) # 合并相邻四个Patch
x = self.norm(x)
x = self.reduction(x) # 降维
return x
四、性能对比与实验分析
4.1 在ImageNet上的表现
模型 | Top-1 Acc | Params | FLOPs |
---|---|---|---|
ResNet-50 | 76.1% | 25M | 4.1G |
ViT-B/16 | 79.9% | 86M | 17.6G |
Swin-T | 81.3% | 29M | 4.5G |
4.2 在COCO目标检测上的表现
模型 | [email protected] | [email protected] |
---|---|---|
Mask R-CNN+Res50 | 41.0 | 37.8 |
Cascade Mask+Swin | 53.9 | 47.2 |
结论:Swin在保持较低计算量的同时,显著提升了下游任务性能。
五、实战技巧与优化方向
- 训练策略:
- 使用AdamW优化器(lr=1e-3,weight_decay=0.05)
- 余弦退火学习率调度
- 数据增强:MixUp + CutMix
- 显存优化:
- 梯度检查点(Gradient Checkpointing)
- 混合精度训练
- 部署优化:
- 转换为TensorRT引擎
- 使用窗口注意力融合技术
六、总结与展望
Swin Transformer通过三大创新奠定了其CV骨干网络地位:
- 层级结构:自然支持多尺度特征
- 滑动窗口:线性计算复杂度
- 移位窗口:实现跨窗口交互
未来方向:
- 探索动态窗口划分策略
- 与CNN的更深层次融合
- 面向移动端的轻量化设计
思考题:
- 为什么移位窗口需要配合掩码机制?
- 如何理解相对位置编码的物理意义?
下期预告:《Transformer在多模态中的应用:CLIP模型原理解析》
资源推荐:
(注:实验数据基于COCO 2017和ImageNet-1K数据集,完整复现需8×A100 GPU)