【即插即用涨点模块】EGA边缘引导注意力:有效保留高频边缘信息,提升分割精度,助力高效涨点【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称 项目名称
1.【人脸识别与管理系统开发 2.【车牌识别与自动收费管理系统开发
3.【手势识别系统开发 4.【人脸面部活体检测系统开发
5.【图片风格快速迁移软件开发 6.【人脸表表情识别系统
7.【YOLOv8多目标识别与自动标注软件开发 8.【基于深度学习的行人跌倒检测系统
9.【基于深度学习的PCB板缺陷检测系统 10.【基于深度学习的生活垃圾分类目标检测系统
11.【基于深度学习的安全帽目标检测系统 12.【基于深度学习的120种犬类检测与识别系统
13.【基于深度学习的路面坑洞检测系统 14.【基于深度学习的火焰烟雾检测系统
15.【基于深度学习的钢材表面缺陷检测系统 16.【基于深度学习的舰船目标分类检测系统
17.【基于深度学习的西红柿成熟度检测系统 18.【基于深度学习的血细胞检测与计数系统
19.【基于深度学习的吸烟/抽烟行为检测系统 20.【基于深度学习的水稻害虫检测与识别系统
21.【基于深度学习的高精度车辆行人检测与计数系统 22.【基于深度学习的路面标志线检测与识别系统
23.【基于深度学习的智能小麦害虫检测识别系统 24.【基于深度学习的智能玉米害虫检测识别系统
25.【基于深度学习的200种鸟类智能检测与识别系统 26.【基于深度学习的45种交通标志智能检测与识别系统
27.【基于深度学习的人脸面部表情识别系统 28.【基于深度学习的苹果叶片病害智能诊断系统
29.【基于深度学习的智能肺炎诊断系统 30.【基于深度学习的葡萄簇目标检测系统
31.【基于深度学习的100种中草药智能识别系统 32.【基于深度学习的102种花卉智能识别系统
33.【基于深度学习的100种蝴蝶智能识别系统 34.【基于深度学习的水稻叶片病害智能诊断系统
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统 36.【基于深度学习的智能草莓病害检测与分割系统
37.【基于深度学习的复杂场景下船舶目标检测系统 38.【基于深度学习的农作物幼苗与杂草检测系统
39.【基于深度学习的智能道路裂缝检测与分析系统 40.【基于深度学习的葡萄病害智能诊断与防治系统
41.【基于深度学习的遥感地理空间物体检测系统 42.【基于深度学习的无人机视角地面物体检测系统
43.【基于深度学习的木薯病害智能诊断与防治系统 44.【基于深度学习的野外火焰烟雾检测系统
45.【基于深度学习的脑肿瘤智能检测系统 46.【基于深度学习的玉米叶片病害智能诊断与防治系统
47.【基于深度学习的橙子病害智能诊断与防治系统 48.【基于深度学习的车辆检测追踪与流量计数系统
49.【基于深度学习的行人检测追踪与双向流量计数系统 50.【基于深度学习的反光衣检测与预警系统
51.【基于深度学习的危险区域人员闯入检测与报警系统 52.【基于深度学习的高密度人脸智能检测与统计系统
53.【基于深度学习的CT扫描图像肾结石智能检测系统 54.【基于深度学习的水果智能检测系统
55.【基于深度学习的水果质量好坏智能检测系统 56.【基于深度学习的蔬菜目标检测与识别系统
57.【基于深度学习的非机动车驾驶员头盔检测系统 58.【太基于深度学习的阳能电池板检测与分析系统
59.【基于深度学习的工业螺栓螺母检测 60.【基于深度学习的金属焊缝缺陷检测系统
61.【基于深度学习的链条缺陷检测与识别系统 62.【基于深度学习的交通信号灯检测识别
63.【基于深度学习的草莓成熟度检测与识别系统 64.【基于深度学习的水下海生物检测识别系统
65.【基于深度学习的道路交通事故检测识别系统 66.【基于深度学习的安检X光危险品检测与识别系统
67.【基于深度学习的农作物类别检测与识别系统 68.【基于深度学习的危险驾驶行为检测识别系统
69.【基于深度学习的维修工具检测识别系统 70.【基于深度学习的维修工具检测识别系统
71.【基于深度学习的建筑墙面损伤检测系统 72.【基于深度学习的煤矿传送带异物检测系统
73.【基于深度学习的老鼠智能检测系统 74.【基于深度学习的水面垃圾智能检测识别系统
75.【基于深度学习的遥感视角船只智能检测系统 76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统 78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统
79.【基于深度学习的果园苹果检测与计数系统 80.【基于深度学习的半导体芯片缺陷检测系统
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统 82.【基于深度学习的运动鞋品牌检测与识别系统
83.【基于深度学习的苹果叶片病害检测识别系统 84.【基于深度学习的医学X光骨折检测与语音提示系统
85.【基于深度学习的遥感视角农田检测与分割系统 86.【基于深度学习的运动品牌LOGO检测与识别系统
87.【基于深度学习的电瓶车进电梯检测与语音提示系统 88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统 90.【基于深度学习的舌苔舌象检测识别与诊断系统
91.【基于深度学习的蛀牙智能检测与语音提示系统 92.【基于深度学习的皮肤癌智能检测与语音提示系统

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

在这里插入图片描述

论文地址:https://arxiv.org/abs/2309.03329
代码地址:https://github.com/UARK-AICV/MEGANet

摘要

在这里插入图片描述

本文提出了一种名为**多尺度边缘引导注意力网络(MEGANet)**的新方法,用于结肠镜图像中的息肉分割。息肉分割在结直肠癌的早期诊断中起着至关重要的作用,但由于背景复杂、息肉大小和形状多变以及边界模糊,分割任务面临诸多挑战。MEGANet通过结合经典的边缘检测技术和注意力机制,有效保留了高频信息(如边缘和边界),从而提高了分割精度。该方法在五个基准数据集上进行了广泛的实验,结果表明MEGANet在六种评估指标上均优于现有的最先进方法。

创新点

  1. 边缘引导注意力模块(EGA):MEGANet的核心创新是引入了EGA模块,该模块利用拉普拉斯算子来增强边缘信息,解决了弱边界分割问题。

  2. 多尺度边缘信息保留:EGA模块在多个尺度上操作,从低层到高层特征,确保模型能够关注边缘相关信息,从而在每个解码器层次上提升预测精度。

  3. 无参数方法:使用拉普拉斯算子作为无参数方法,有效提取和保留高频边缘信息,避免了传统CNN方法在边缘提取上的不足。

方法总结

MEGANet是一个端到端的框架,包含三个主要模块:

  1. 编码器:负责从输入图像中捕获和抽象特征。
  2. 解码器:专注于提取显著特征,生成与输入图像分辨率匹配的解码图。
  3. 边缘引导注意力模块(EGA):利用拉普拉斯算子增强边缘信息,确保在解码过程中保留高频细节。

MEGANet通过结合编码器、解码器和EGA模块,能够在多个尺度上保留边缘信息,从而提高了息肉分割的精度。

EGA模块的作用

在这里插入图片描述
EGA模块的主要作用是通过拉普拉斯算子提取和保留高频边缘信息,增强模型对弱边界的检测能力。具体来说,EGA模块在每一层接收三个输入:

  1. 编码器特征:来自编码器的视觉特征。
  2. 高频特征:通过拉普拉斯算子提取的边缘信息。
  3. 解码器预测特征:来自更高层的解码器预测特征。

EGA模块通过结合这些输入,生成一个融合特征,该特征能够突出边缘细节,并通过注意力机制引导模型关注关键区域,从而提升分割精度。此外,EGA模块还通过卷积块注意力模块(CBAM)进一步校准特征,确保模型能够准确捕捉边界和背景区域的相关性。

总结

MEGANet通过引入EGA模块,有效解决了息肉分割中的弱边界问题,显著提高了分割精度。该方法在多个数据集上的实验结果表明其优越性,为结直肠癌的早期诊断提供了有力的技术支持。

EGA源码与注释

# Github地址:https://github.com/UARK-AICV/MEGANet
# 论文:MEGANet: Multi-Scale Edge-Guided Attention Network for Weak Boundary Polyp Segmentation, WACV 2024
# 论文地址:https://arxiv.org/abs/2309.03329

import torch
import torch.nn.functional as F
import torch.nn as nn

# 定义高斯核函数,用于生成高斯模糊滤波器
def gauss_kernel(channels=3, cuda=True):
    # 创建一个5x5的高斯核
    kernel = torch.tensor([[1., 4., 6., 4., 1],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    # 归一化高斯核
    kernel /= 256.
    # 将高斯核扩展到多个通道
    kernel = kernel.repeat(channels, 1, 1, 1)
    if cuda:
        # 如果使用GPU,将高斯核移动到GPU
        kernel = kernel.cuda()
    return kernel

# 定义下采样函数,通过每隔一个像素取值实现
def downsample(x):
    return x[:, :, ::2, ::2]

# 定义卷积高斯模糊函数,使用高斯核对图像进行模糊处理
def conv_gauss(img, kernel):
    # 使用反射填充图像边缘
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    # 应用卷积操作进行高斯模糊
    out = F.conv2d(img, kernel, groups=img.shape[1])
    return out

# 定义上采样函数,通过插入零值实现
def upsample(x, channels):
    # 在每个像素之间插入零值
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    # 对上采样后的图像应用高斯模糊
    return conv_gauss(x_up, 4 * gauss_kernel(channels))

# 定义拉普拉斯金字塔的一个层级,计算图像与高斯模糊后上采样的图像的差异
def make_laplace(img, channels):
    # 对图像进行高斯模糊
    filtered = conv_gauss(img, gauss_kernel(channels))
    # 对模糊后的图像进行下采样
    down = downsample(filtered)
    # 对下采样后的图像进行上采样
    up = upsample(down, channels)
    # 如果上采样后的图像尺寸与原图不同,进行插值调整
    if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]:
        up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3]))
    # 计算原图与上采样后的图像的差异
    diff = img - up
    return diff

# 构建拉普拉斯金字塔,包含多个层级的差异图像和最终的下采样图像
def make_laplace_pyramid(img, level, channels):
    current = img
    pyr = []
    for _ in range(level):
        # 对当前图像计算拉普拉斯层级
        filtered = conv_gauss(current, gauss_kernel(channels))
        down = downsample(filtered)
        up = upsample(down, channels)
        if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
            up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
        diff = current - up
        pyr.append(diff)
        current = down
    # 最后一个层级为最终的下采样图像
    pyr.append(current)
    return pyr

# 定义通道注意力模块,用于计算通道级别的注意力权重
class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # 定义MLP网络,用于计算通道注意力权重
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )

    def forward(self, x):
        # 计算平均池化后的通道注意力权重
        avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        # 计算最大池化后的通道注意力权重
        max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        # 将平均池化和最大池化后的权重相加
        channel_att_sum = avg_out + max_out

        # 将权重通过sigmoid函数归一化,并扩展到与输入相同的尺寸
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        # 将权重应用到输入特征图
        return x * scale

# 定义空间注意力模块,用于计算空间级别的注意力权重
class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        # 定义卷积层,用于计算空间注意力权重
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        # 计算最大池化和平均池化后的特征图
        x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        # 将特征图通过卷积层计算空间注意力权重
        x_out = self.spatial(x_compress)
        # 将权重通过sigmoid函数归一化
        scale = torch.sigmoid(x_out)  # broadcasting
        # 将权重应用到输入特征图
        return x * scale

# 定义CBAM模块,结合通道注意力和空间注意力
class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        # 初始化通道注意力模块
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
        # 初始化空间注意力模块
        self.SpatialGate = SpatialGate()

    def forward(self, x):
        # 应用通道注意力
        x_out = self.ChannelGate(x)
        # 应用空间注意力
        x_out = self.SpatialGate(x_out)
        return x_out

# 定义Edge-Guided Attention Module(EGA)模块,用于结合边缘信息和预测结果进行特征融合
class EGA(nn.Module):
    def __init__(self, in_channels):
        super(EGA, self).__init__()

        # 定义特征融合卷积层
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True))

        # 定义注意力机制卷积层
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid())

        # 初始化CBAM模块
        self.cbam = CBAM(in_channels)

    def forward(self, edge_feature, x, pred):
        residual = x
        xsize = x.size()[2:]

        # 将预测结果通过sigmoid函数归一化
        pred = torch.sigmoid(pred)

        # 计算背景注意力权重
        background_att = 1 - pred
        # 应用背景注意力权重到特征图
        background_x = x * background_att

        # 计算边界注意力权重
        edge_pred = make_laplace(pred, 1)
        # 应用边界注意力权重到特征图
        pred_feature = x * edge_pred

        # 计算高频特征
        edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True)
        # 应用高频特征到特征图
        input_feature = x * edge_input

        # 将背景特征、边界特征和高频特征进行拼接
        fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1)
        # 应用特征融合卷积层
        fusion_feature = self.fusion_conv(fusion_feature)

        # 计算注意力权重
        attention_map = self.attention(fusion_feature)
        # 应用注意力权重到融合特征
        fusion_feature = fusion_feature * attention_map

        # 将融合特征与残差相加
        out = fusion_feature + residual
        # 应用CBAM模块
        out = self.cbam(out)
        return out

if __name__ == '__main__':
    # 模拟输入张量
    edge_feature = torch.randn(1, 1, 128, 128).cuda()
    x = torch.randn(1, 64, 128, 128).cuda()
    pred = torch.randn(1, 1, 128, 128).cuda()  # pred 通常是1通道

    # 实例化 EGA 类
    block = EGA(64).cuda()

    # 传递输入张量通过 EGA 实例
    output = block(edge_feature, x, pred)

    # 打印输入和输出的形状
    print(edge_feature.size())
    print(x.size())
    print(pred.size())
    print(output.size())

在这里插入图片描述

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

猜你喜欢

转载自blog.csdn.net/qq_42589613/article/details/146328041
今日推荐