引言:当Transformer遇见图像
传统图像分类任务由CNN主导,但Transformer凭借其全局建模能力,在ImageNet等基准任务中刷新了记录。2020年,Vision Transformer(ViT)的提出标志着Transformer正式进军CV领域。本文将以实战为导向,详解如何用纯Transformer实现图像分类,并提供完整PyTorch代码实现。
一、ViT核心思想:图像即序列
ViT的核心创新在于将图像视为由Patch组成的序列,其处理流程分为四步:
- 图像分块:将输入图像(224×224)分割为16×16的Patch(共196个)
- 线性嵌入:将每个Patch展平为向量(16×16×3=768维)
- 添加位置编码:保留空间位置信息
- 输入Transformer编码器:提取全局特征
二、关键代码实现
2.1 Patch Embedding层
import torch
from torch import nn
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size,
stride=patch_size) # 用卷积实现分块
def forward(self, x):
x = self.proj(x) # [B, 768, 14, 14] (224/16=14)
x = x.flatten(2).transpose(1, 2) # [B, 196, 768]
return x
2.2 位置编码
ViT采用可学习的1D位置编码:
class ViT(nn.Module):
def __init__(self, num_classes=1000):
super().__init__()
self.patch_embed = PatchEmbed()
self.pos_embed = nn.Parameter(torch.randn(1, 197, 768)) # 196+1个位置
# 添加分类token [class]
self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # [B, 196, 768]
# 拼接[class] token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # [B, 197, 768]
# 添加位置编码
x += self.pos_embed
return x
三、Transformer编码器设计
ViT仅使用Encoder部分,包含交替的多头自注意力(MSA)和前馈网络(MLP):
class Block(nn.Module):
def __init__(self, dim=768, num_heads=12):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = nn.MultiheadAttention(dim, num_heads)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, 3072),
nn.GELU(),
nn.Linear(3072, dim)
def forward(self, x):
# 残差连接1
x = x + self.attn(self.norm1(x))[0]
# 残差连接2
x = x + self.mlp(self.norm2(x))
return x
class ViTEncoder(nn.Module):
def __init__(self, depth=12):
super().__init__()
self.blocks = nn.ModuleList([Block() for _ in range(depth)])
def forward(self, x):
for blk in self.blocks:
x = blk(x)
return x
四、训练策略与技巧
4.1 混合精度训练
使用FP16加速训练,减少显存占用:
scaler = torch.cuda.amp.GradScaler()
with torch.cuda.amp.autocast():
output = model(inputs)
loss = criterion(output, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
4.2 学习率预热
前500步线性增加学习率,避免初期震荡:
def warmup_lr(step, warmup_steps=500, base_lr=1e-3):
return base_lr * min(step / warmup_steps, 1.0)
五、实战:CIFAR-10分类示例
5.1 数据预处理
调整图像尺寸并归一化:
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
5.2 模型微调
修改ViT分类头适配CIFAR-10:
model = ViT(num_classes=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
5.3 训练结果
在RTX 3090上训练50 epoch后:
训练集准确率:98.2%
测试集准确率:93.5%
六、ViT与CNN的对比分析
指标 | ViT-B/16 | ResNet-50 |
---|---|---|
Top-1 Acc | 84.15% | 76.15% |
参数量 | 86M | 25M |
训练数据需求 | >14M图像 | ~1M图像 |
推理速度 | 178ms/img | 76ms/img |
## 结论:
- ViT在大数据场景下表现更优
- CNN在小数据场景仍具优势
七、延展与优化方向
- 高效注意力:使用Swin Transformer的窗口注意力降低计算量
- 混合架构:CNN+Transformer结合(如ResNet50-ViT)
- 知识蒸馏:用大模型指导小模型训练
# 示例:Swin Transformer块
class SwinBlock(nn.Module):
def __init__(self, dim, window_size=7):
super().__init__()
self.w_msa = WindowMSA(dim, window_size)
self.mlp = nn.Sequential(...)
八、总结
ViT的成功证明了Transformer在CV领域的强大潜力,但其应用仍需注意:
- 数据量要求:建议在>100万图像的数据集上训练
- 计算资源:需要GPU集群支持预训练
- 位置编码:对图像旋转敏感,需设计更鲁棒的编码方案
下期预告:《Swin Transformer原理详解:让Transformer真正成为视觉通用骨干》
资源下载:
(注:实验数据基于PyTorch 1.10 + CUDA 11.3环境,完整复现需调整超参数)