结合CNN和Vision Transformer(ViT)可以通过多种方法实现两者的优势互补,以下是一些具体方案及实现步骤:
1. 混合架构(Hybrid Architecture)
核心思想:将CNN作为局部特征提取器,ViT处理全局依赖关系。
实现方式:
-
前端CNN + 后端ViT
-
使用CNN(如ResNet、EfficientNet)提取图像特征图。
-
将特征图展平为序列,输入ViT进行全局建模。
-
输出分类/检测结果。
-
-
后端CNN + 前端ViT
-
使用ViT分割图像为Patch,生成全局特征。
-
通过转置卷积或插值恢复空间分辨率,输入CNN细化细节。
适用场景:图像分割、超分辨率等需要高分辨率输出的任务。
-
2. 并行结构(Parallel Branches)
核心思想:同时运行CNN和ViT分支,融合两者的特征。
实现方式:
-
特征拼接/加权融合
-
并行计算CNN和ViT的特征。
-
将特征图拼接或通过注意力机制融合。
-
class ParallelModel(nn.Module):
def __init__(self):
super().__init__()
self.cnn_branch = resnet18(pretrained=True)
self.vit_branch = ViT(image_size=224, patch_size=16)
self.fusion = nn.Linear(512 + 768, 1000) # 融合特征
def forward(self, x):
cnn_feat = self.cnn_branch(x) # [B, 512]
vit_feat = self.vit_branch(x) # [B, 768]
fused = torch.cat([cnn_feat, vit_feat], dim=1)
return self.fusion(fused)
3. 分阶段设计(Stage-wise Design)
核心思想:在不同阶段交替使用CNN和ViT模块。
实现方式:
-
浅层CNN + 深层ViT
-
浅层用CNN提取低级特征(边缘、纹理)。
-
深层用ViT建模高级语义关系。
参考模型:CoAtNet(CNN+Transformer混合堆叠)。
-
-
局部窗口注意力
在ViT中引入卷积操作,例如:-
MobileViT:用卷积处理局部窗口,再用Transformer跨窗口交互。
-
Swin Transformer:局部窗口自注意力 + CNN式层级下采样。
-
4. 卷积增强的Transformer模块
核心思想:在ViT中嵌入卷积层,增强局部感知。
实现方式:
-
在Patch Embedding中替换为卷积
使用卷积层代替ViT的线性投影生成Patch Embedding。
class ConvStem(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 768, kernel_size=16, stride=16) # 等效于Patch分割
def forward(self, x):
x = self.conv(x) # [B, 768, H/16, W/16]
x = x.flatten(2).transpose(1, 2) # [B, N, C]
return x
- 在Transformer Block中加入卷积
每个Transformer Block后添加深度可分离卷积:
class ConvolutionalTransformerBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.attn = MultiHeadAttention(dim)
self.conv = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) # 深度可分离卷积
def forward(self, x):
x = self.attn(x)
x = x.permute(0, 2, 1).view(B, C, H, W) # 恢复为图像格式
x = self.conv(x)
x = x.view(B, C, -1).permute(0, 2, 1)
return x
5. 自适应特征融合(Adaptive Fusion)
核心思想:动态调整CNN和ViT的贡献权重。
实现方式:
-
门控机制(Gating Network)
通过可学习参数自动选择CNN和ViT的特征:
class AdaptiveFusion(nn.Module):
def __init__(self, cnn_dim, vit_dim):
super().__init__()
self.gate = nn.Linear(cnn_dim + vit_dim, 2) # 学习权重
def forward(self, cnn_feat, vit_feat):
combined = torch.cat([cnn_feat, vit_feat], dim=1)
weights = F.softmax(self.gate(combined), dim=1) # [B, 2]
return weights[:, 0] * cnn_feat + weights[:, 1] * vit_feat
6. 轻量化设计(针对计算效率)
适用场景:资源受限时,结合MobileNet等轻量CNN与蒸馏后的ViT。
实现方式:
-
知识蒸馏:用大型CNN或ViT作为教师模型,训练混合结构的小模型。
-
Neck网络设计:仅替换模型的一部分(如将ResNet的最后阶段替换为Transformer)。
实践建议
-
预训练权重利用:优先加载CNN部分的预训练权重(如ImageNet预训练ResNet),ViT部分可从头训练或微调。
-
位置编码适配:若使用ViT处理CNN特征,需重新设计位置编码(如可学习的2D位置编码)。
-
下游任务适配:
-
分类任务:全局池化后接全连接层。
-
检测任务:保留空间维度,输出特征金字塔。
-
-
参数量平衡:避免CNN或ViT某一方过于庞大,导致模型失衡。
经典论文参考
-
CoAtNet(CNN + Transformer混合堆叠)
-
MobileViT(轻量级CNN与局部全局Transformer)
-
BoTNet(ResNet中的Bottleneck替换为自注意力)
-
CMT(CNN与Transformer并行分支)
以下是一个实现 多级CNN特征融合的ViT 的完整代码方案,基于PyTorch框架,结合了ResNet的多级特征和Vision Transformer的全局建模能力:
非最终代码!!
import torch
import torch.nn as nn
from torchvision.models import resnet50
from timm.models.vision_transformer import VisionTransformer
# 多级特征适配器模块
class MultiStageAdapter(nn.Module):
def __init__(self, cnn_channels=[256, 512, 1024], vit_dim=768, fusion_dim=256):
super().__init__()
# 为每个CNN阶段定义适配器
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Conv2d(ch, vit_dim, kernel_size=1), # 通道对齐
nn.BatchNorm2d(vit_dim),
nn.GELU(),
nn.Conv2d(vit_dim, vit_dim, 3, padding=1) # 空间信息保持
) for ch in cnn_channels
])
# 特征融合模块
self.fusion = nn.Sequential(
nn.Conv2d(vit_dim*len(cnn_channels), fusion_dim, 1),
nn.ReLU(inplace=True),
nn.Conv2d(fusion_dim, vit_dim, 3, padding=1)
)
# 上采样器(用于对齐不同尺度的特征)
self.upsamplers = nn.ModuleList([
nn.Upsample(scale_factor=2**i, mode='bilinear', align_corners=False)
for i in range(len(cnn_channels)-1, 0, -1)
])
def forward(self, features):
"""
features: list of CNN特征,从浅到深排序
[stage1_feat, stage2_feat, stage3_feat]
"""
# 处理每个适配器
adapted_feats = []
for feat, adapter in zip(features, self.adapters):
x = adapter(feat)
adapted_feats.append(x)
# 上采样对齐尺寸(以最深层特征为基准)
target_size = adapted_feats[-1].shape[-2:]
for i in range(len(adapted_feats)-1):
adapted_feats[i] = self.upsamplers[i](adapted_feats[i])
adapted_feats[i] = torch.nn.functional.interpolate(
adapted_feats[i], size=target_size, mode='bilinear')
# 通道维度拼接
fused = torch.cat(adapted_feats, dim=1) # [B, vit_dim*3, H, W]
# 融合后处理
return self.fusion(fused) # [B, vit_dim, H, W]
# 多级特征ViT主模型
class MultiScaleFusionViT(nn.Module):
def __init__(self, num_classes=1000, vit_model=None):
super().__init__()
# 1. CNN特征提取器(ResNet50前三个stage)
resnet = resnet50(pretrained=True)
self.cnn_stages = nn.ModuleDict({
'stem': nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool),
'stage1': resnet.layer1, # stride 4
'stage2': resnet.layer2, # stride 8
'stage3': resnet.layer3 # stride 16
})
# 2. 多级特征适配器
self.adapter = MultiStageAdapter(
cnn_channels=[256, 512, 1024], # ResNet各stage输出通道
vit_dim=768,
fusion_dim=512
)
# 3. 冻结的ViT骨干
if vit_model is None:
self.vit = VisionTransformer(
img_size=14, # 适配器输出特征图尺寸
patch_size=1, # 每个"patch"对应1x1的特征
in_chans=768,
num_classes=num_classes
)
else:
self.vit = vit_model
self._freeze_vit()
# 4. CLS Token和位置编码(保持与ViT兼容)
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768)) # 14x14=196 +1
def _freeze_vit(self):
# 冻结ViT所有参数
for param in self.vit.parameters():
param.requires_grad = False
# 可选:解冻LayerNorm和位置编码
for name, module in self.vit.named_modules():
if isinstance(module, nn.LayerNorm):
for param in module.parameters():
param.requires_grad = True
self.pos_embed.requires_grad = True
def forward(self, x):
# 步骤1:提取多级CNN特征
cnn_features = []
x = self.cnn_stages['stem'](x)
x = self.cnn_stages['stage1'](x) # [B,256,56,56]
cnn_features.append(x)
x = self.cnn_stages['stage2'](x) # [B,512,28,28]
cnn_features.append(x)
x = self.cnn_stages['stage3'](x) # [B,1024,14,14]
cnn_features.append(x)
# 步骤2:多级特征融合
vit_feat = self.adapter(cnn_features) # [B,768,14,14]
# 步骤3:准备ViT输入
B, C, H, W = vit_feat.shape
vit_feat = vit_feat.flatten(2).transpose(1, 2) # [B,196,768]
# 添加CLS Token
cls_tokens = self.cls_token.expand(B, -1, -1)
vit_feat = torch.cat([cls_tokens, vit_feat], dim=1) # [B,197,768]
# 添加位置编码
vit_feat += self.pos_embed
# 步骤4:通过冻结的ViT
vit_feat = self.vit.blocks(vit_feat)
vit_feat = self.vit.norm(vit_feat)
# 分类头
return self.vit.head(vit_feat[:, 0])
# 使用示例
if __name__ == '__main__':
# 加载预训练ViT(示例使用timm库)
from timm.models import vit_base_patch16_224
pretrained_vit = vit_base_patch16_224(pretrained=True)
# 创建多级融合模型
model = MultiScaleFusionViT(
num_classes=1000,
vit_model=pretrained_vit
)
# 打印可训练参数
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params/1e6:.2f}M")
# 测试前向
x = torch.randn(2, 3, 224, 224)
output = model(x)
print(f"Output shape: {output.shape}") # [2,1000]
代码关键设计解析
-
多级特征提取:
-
使用ResNet50的前三个stage(stem+layer1-3)
-
获取不同尺度的特征图:
-
Stage1: [256, 56, 56]
-
Stage2: [512, 28, 28]
-
Stage3: [1024, 14, 14]
-
-
-
特征适配与融合:
-
通道对齐:通过1x1卷积统一各stage特征到ViT的嵌入维度(768)
-
空间对齐:使用双线性插值上采样浅层特征,与深层特征尺寸对齐
-
特征融合:通道维度拼接后通过卷积层融合多尺度信息
-
-
ViT输入处理:
-
将融合后的特征展平为序列
-
添加可学习的CLS Token和位置编码
-
通过冻结的ViT Blocks进行特征转换
-
-
训练优化:
-
仅训练CNN特征提取器、适配器模块和位置编码
-
冻结ViT主体参数(可选解冻LayerNorm)
-
特征融合可视化
ResNet特征金字塔:
Stage1 ────┐ (56x56)
Stage2 ───┤ (28x28)
Stage3 ───┘ (14x14)
│
▼
[多级适配器:通道对齐+上采样]
│
▼
特征融合(拼接+卷积)
│
▼
ViT输入序列
效果增强技巧
-
动态特征加权:
# 在MultiStageAdapter中添加注意力权重
class ChannelAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels//16),
nn.ReLU(),
nn.Linear(in_channels//16, in_channels),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.gap(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
# 在适配器中插入注意力
self.adapters = nn.ModuleList([
nn.Sequential(
...,
ChannelAttention(vit_dim), # 新增通道注意力
...
) for ch in cnn_channels
])
-
跨阶段跳跃连接:
# 修改适配器前向传播
def forward(self, features):
adapted_feats = []
prev_feat = None
for i, (feat, adapter) in enumerate(zip(features, self.adapters)):
# 与前一阶段特征融合
if prev_feat is not None:
feat = feat + F.interpolate(prev_feat, scale_factor=0.5)
x = adapter(feat)
adapted_feats.append(x)
prev_feat = x
...
-
多尺度位置编码:
# 为每个尺度添加独立的位置编码
self.stage_pos_embeds = nn.ParameterList([
nn.Parameter(torch.randn(1, 768, 56, 56)),
nn.Parameter(torch.randn(1, 768, 28, 28)),
nn.Parameter(torch.randn(1, 768, 14, 14))
])
# 在前向传播中添加
for i, feat in enumerate(adapted_feats):
feat = feat + self.stage_pos_embeds[i]