【即插即用涨点模块】单头自注意力模块SHSA:避免冗余,高效计算,性能提升利器【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称 项目名称
1.【人脸识别与管理系统开发 2.【车牌识别与自动收费管理系统开发
3.【手势识别系统开发 4.【人脸面部活体检测系统开发
5.【图片风格快速迁移软件开发 6.【人脸表表情识别系统
7.【YOLOv8多目标识别与自动标注软件开发 8.【基于深度学习的行人跌倒检测系统
9.【基于深度学习的PCB板缺陷检测系统 10.【基于深度学习的生活垃圾分类目标检测系统
11.【基于深度学习的安全帽目标检测系统 12.【基于深度学习的120种犬类检测与识别系统
13.【基于深度学习的路面坑洞检测系统 14.【基于深度学习的火焰烟雾检测系统
15.【基于深度学习的钢材表面缺陷检测系统 16.【基于深度学习的舰船目标分类检测系统
17.【基于深度学习的西红柿成熟度检测系统 18.【基于深度学习的血细胞检测与计数系统
19.【基于深度学习的吸烟/抽烟行为检测系统 20.【基于深度学习的水稻害虫检测与识别系统
21.【基于深度学习的高精度车辆行人检测与计数系统 22.【基于深度学习的路面标志线检测与识别系统
23.【基于深度学习的智能小麦害虫检测识别系统 24.【基于深度学习的智能玉米害虫检测识别系统
25.【基于深度学习的200种鸟类智能检测与识别系统 26.【基于深度学习的45种交通标志智能检测与识别系统
27.【基于深度学习的人脸面部表情识别系统 28.【基于深度学习的苹果叶片病害智能诊断系统
29.【基于深度学习的智能肺炎诊断系统 30.【基于深度学习的葡萄簇目标检测系统
31.【基于深度学习的100种中草药智能识别系统 32.【基于深度学习的102种花卉智能识别系统
33.【基于深度学习的100种蝴蝶智能识别系统 34.【基于深度学习的水稻叶片病害智能诊断系统
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统 36.【基于深度学习的智能草莓病害检测与分割系统
37.【基于深度学习的复杂场景下船舶目标检测系统 38.【基于深度学习的农作物幼苗与杂草检测系统
39.【基于深度学习的智能道路裂缝检测与分析系统 40.【基于深度学习的葡萄病害智能诊断与防治系统
41.【基于深度学习的遥感地理空间物体检测系统 42.【基于深度学习的无人机视角地面物体检测系统
43.【基于深度学习的木薯病害智能诊断与防治系统 44.【基于深度学习的野外火焰烟雾检测系统
45.【基于深度学习的脑肿瘤智能检测系统 46.【基于深度学习的玉米叶片病害智能诊断与防治系统
47.【基于深度学习的橙子病害智能诊断与防治系统 48.【基于深度学习的车辆检测追踪与流量计数系统
49.【基于深度学习的行人检测追踪与双向流量计数系统 50.【基于深度学习的反光衣检测与预警系统
51.【基于深度学习的危险区域人员闯入检测与报警系统 52.【基于深度学习的高密度人脸智能检测与统计系统
53.【基于深度学习的CT扫描图像肾结石智能检测系统 54.【基于深度学习的水果智能检测系统
55.【基于深度学习的水果质量好坏智能检测系统 56.【基于深度学习的蔬菜目标检测与识别系统
57.【基于深度学习的非机动车驾驶员头盔检测系统 58.【太基于深度学习的阳能电池板检测与分析系统
59.【基于深度学习的工业螺栓螺母检测 60.【基于深度学习的金属焊缝缺陷检测系统
61.【基于深度学习的链条缺陷检测与识别系统 62.【基于深度学习的交通信号灯检测识别
63.【基于深度学习的草莓成熟度检测与识别系统 64.【基于深度学习的水下海生物检测识别系统
65.【基于深度学习的道路交通事故检测识别系统 66.【基于深度学习的安检X光危险品检测与识别系统
67.【基于深度学习的农作物类别检测与识别系统 68.【基于深度学习的危险驾驶行为检测识别系统
69.【基于深度学习的维修工具检测识别系统 70.【基于深度学习的维修工具检测识别系统
71.【基于深度学习的建筑墙面损伤检测系统 72.【基于深度学习的煤矿传送带异物检测系统
73.【基于深度学习的老鼠智能检测系统 74.【基于深度学习的水面垃圾智能检测识别系统
75.【基于深度学习的遥感视角船只智能检测系统 76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统 78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统
79.【基于深度学习的果园苹果检测与计数系统 80.【基于深度学习的半导体芯片缺陷检测系统
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统 82.【基于深度学习的运动鞋品牌检测与识别系统
83.【基于深度学习的苹果叶片病害检测识别系统 84.【基于深度学习的医学X光骨折检测与语音提示系统
85.【基于深度学习的遥感视角农田检测与分割系统 86.【基于深度学习的运动品牌LOGO检测与识别系统
87.【基于深度学习的电瓶车进电梯检测与语音提示系统 88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

在这里插入图片描述

论文地址:https://arxiv.org/pdf/2401.16456
代码地址:https://github.com/ysj9909/SHViT

创新点

在这里插入图片描述

  1. 内存高效的宏观设计:文章提出了一种新的宏观设计,使用更大的步幅(16x16)进行patch embedding,并采用3阶段结构,减少了早期阶段的空间冗余,从而降低了内存访问成本。
  2. 单头注意力机制(SHSA):提出了一种单头自注意力模块,避免了多头注意力机制中的冗余,同时通过并行结合全局和局部信息来提升准确性。
  3. 综合性能优化:通过结合上述设计,提出了SHViT(单头视觉Transformer),在多种设备上实现了速度和准确性的最佳平衡。

方法

  1. 宏观设计分析:通过实验分析了传统4x4 patch embedding和4阶段结构的冗余,发现使用16x16 patch embedding和3阶段设计可以在减少计算成本的同时保持性能。
  2. 微观设计分析:深入研究了多头自注意力机制(MHSA)中的冗余,发现多头机制在后期的冗余尤为明显,提出了单头自注意力模块(SHSA)来替代。
  3. SHViT架构:基于上述分析,提出了SHViT架构,包括四个3x3步幅卷积层、三个阶段的SHViT块(包含深度卷积、SHSA和FFN模块),以及高效的降采样层。

SHSA模块的作用

在这里插入图片描述

  1. 减少计算冗余:SHSA模块仅对输入通道的一部分应用单头自注意力,避免了多头机制中的计算冗余。
  2. 降低内存访问成本:通过处理部分通道,减少了内存访问成本,使得模型在GPU和CPU上能够更高效地运行。
  3. 提升性能:SHSA模块通过并行结合全局和局部信息,提升了模型的准确性,同时允许在相同的计算预算下堆叠更多的块,进一步提升了性能。

源码与注释

import torch


class GroupNorm(torch.nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class Conv2d_BN(torch.nn.Sequential):
    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1):
        super().__init__()
        # 添加卷积层
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        # 添加批量归一化层
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        # 初始化批量归一化层的权重和偏置
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        # 融合卷积层和批量归一化层
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
            device=c.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


class SHSA(torch.nn.Module):
    """Single-Head Self-Attention"""

    def __init__(self, dim, qk_dim=16, pdim=32):
        super().__init__()
        # 计算缩放因子
        self.scale = qk_dim ** -0.5
        self.qk_dim = qk_dim
        self.dim = dim
        self.pdim = pdim

        # 添加预归一化层
        self.pre_norm = GroupNorm(pdim)

        # 添加卷积层用于生成查询、键和值
        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
        # 添加投影层
        self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
            dim, dim, bn_weight_init=0))

    def forward(self, x):
        B, C, H, W = x.shape
        # 将输入张量按通道维度拆分为两部分
        x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim=1)
        # 对第一部分应用预归一化
        x1 = self.pre_norm(x1)
        # 生成查询、键和值
        qkv = self.qkv(x1)
        q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim=1)
        q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)

        # 计算注意力分数
        attn = (q.transpose(-2, -1) @ k) * self.scale
        # 应用softmax函数
        attn = attn.softmax(dim=-1)
        # 计算加权和
        x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
        # 将加权和与未处理的部分拼接并应用投影层
        x = self.proj(torch.cat([x1, x2], dim=1))

        return x


if __name__ == '__main__':
    # 创建SHSA模块实例
    block = SHSA(64)  # 输入通道数C

    # 创建随机输入张量
    input = torch.randn(1, 64, 32, 32)  # 输入形状为[B, C, H, W]

    # 打印输入张量的形状
    print(input.size())

    # 通过SHSA模块进行前向传播
    output = block(input)

    # 打印输出张量的形状
    print(output.size())

总结

文章通过系统分析现有设计中的冗余,提出了内存高效的宏观设计和单头自注意力模块,构建了SHViT模型,在多种设备上实现了速度和准确性的最佳平衡。SHSA模块通过减少计算冗余和内存访问成本,提升了模型的性能和效率。


在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

猜你喜欢

转载自blog.csdn.net/qq_42589613/article/details/146243843
今日推荐