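Improvement Series (2): Adding an Attention Module to Unet, with an Experimental Comparison against unet (Spine Segmentation, Synapse Multi-Organ Segmentation)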

Contents

1. Preface

2. Code overview

3. Binary segmentation of the human spine

3.1 unet network

3.2 attentionUnet

4. Synapse multi-organ segmentation (8 classes)

4.1 unet network

4.2 attentionUnet


1. Preface

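An attention module is a mechanism added to deep learning models to make them focus on the most important information. It loosely mimics the way humans attend to information: it dynamically allocates attention to select the parts of the input the model should concentrate on.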


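In a conventional deep learning model, every position and channel of the input features is processed with the same fixed weights, regardless of how relevant it is to the current input. An attention module lets the model learn, according to the task, to focus selectively on different parts of the input features. The model can then concentrate on the important features, which can improve its representational power and generalization.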

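The core idea of an attention module is to compute a weight for each feature from a similarity score (usually a dot product or another similarity measure) between the input features and a set of learned parameters. These weights act as attention scores that indicate how much the model should focus on each feature. The features are then multiplied by their weights, and the resulting weighted representation is used for subsequent processing and prediction.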
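As a minimal illustration of the dot-product form described above, the sketch below scores each feature vector against a query vector, turns the scores into softmax weights, and forms the weighted sum. The query is a hand-picked stand-in for learned parameters, and `attention_pool` is a hypothetical helper. Note that the `Attention_block` used in this article is an additive gate with a sigmoid, not a softmax over features.

```python
import math


def attention_pool(features, query):
    """Weight each feature vector by its dot-product similarity to a query vector."""
    # Similarity score for every feature: dot product with the query.
    scores = [sum(f_i * q_i for f_i, q_i in zip(f, query)) for f in features]
    # Softmax turns the scores into attention weights that are positive and sum to 1.
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted feature representation: sum of the features scaled by their weights.
    dim = len(features[0])
    pooled = [sum(w * f[d] for w, f in zip(weights, features)) for d in range(dim)]
    return weights, pooled


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
q = [2.0, 0.0]  # in a real model this vector would be learned
w, pooled = attention_pool(feats, q)
print([round(x, 3) for x in w], [round(x, 3) for x in pooled])
```

Features that align with the query (the first and third here) receive most of the weight, so the pooled vector is dominated by them.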

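Attention modules are widely used across deep learning tasks such as machine translation, image captioning, and question answering. Besides improving performance, they can help explain a model's decisions, because the attention weights show which inputs the model relied on. Attention has therefore become an important building block of deep learning models.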

The attention block used in attentionUnet is implemented as follows:

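import torch.nn as nn


class Attention_block(nn.Module):
    """Additive attention gate: re-weights the feature x using the gating signal g."""

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # 1x1 conv projecting the gating signal g (typically the upsampled decoder feature) to F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # 1x1 conv projecting the skip-connection feature x to F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # Collapse to a single-channel attention map with values in (0, 1)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g and x must have the same spatial size (H, W) for the element-wise sum below
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # attention map, shape (N, 1, H, W)

        # Broadcast the attention map over all F_l channels of x
        return x * psi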
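Written out, the forward pass of `Attention_block` computes the following, where $*$ denotes a 1×1 convolution with bias, $\mathrm{BN}$ is batch normalization, $\sigma$ is the sigmoid, and $W_g$, $W_x$, $W_\psi$ correspond to `self.W_g`, `self.W_x` and `self.psi` (mapping $F_g \to F_{int}$, $F_l \to F_{int}$ and $F_{int} \to 1$ channels respectively):

```latex
\begin{aligned}
q      &= \mathrm{ReLU}\big(\mathrm{BN}(W_g * g) + \mathrm{BN}(W_x * x)\big), \\
\alpha &= \sigma\big(\mathrm{BN}(W_\psi * q)\big), \qquad \alpha \in (0,1)^{1 \times H \times W}, \\
\hat{x} &= x \odot \alpha .
\end{aligned}
```

Because $\alpha$ has a single channel, the same spatial weight is applied to all $F_l$ channels of $x$: the gate can suppress irrelevant regions of the skip feature, but it does not re-weight individual channels.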

2. Code overview

The training arguments are:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='unet', type=str, help='unet,attunet')    # model: unet or attunet

    parser.add_argument("--img-size", default=[224, 224], nargs=2, type=int, help='input image size')    # size of the preprocessed input image
    parser.add_argument("--ct", action='store_true', help='is CT?')    # pass --ct for CT data; type=bool would parse any non-empty string, even 'False', as True

    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=100, type=int)

    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf', default=0.001, type=float)                  # final learning rate = lr * lrf

    parser.add_argument("--img_f", default='.jpg', type=str)               # file suffix of the images
    parser.add_argument("--mask_f", default='.png', type=str)              # file suffix of the masks
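The comment on `--lrf` only fixes the endpoint (final lr = lr * lrf). The shape of the decay in between is not shown in this article. The sketch below assumes a cosine decay, which is one common way to go from `lr` to `lr * lrf`; `cosine_lr` is a hypothetical helper, not the author's scheduler.

```python
import math


def cosine_lr(epoch: int, epochs: int = 100, lr: float = 0.01, lrf: float = 0.001) -> float:
    """Learning rate at a 0-based epoch, decaying from lr to lr * lrf (assumed cosine shape)."""
    # Fraction of training completed; the last epoch (epochs - 1) maps to 1.0.
    progress = epoch / max(epochs - 1, 1)
    # Cosine factor goes from 1 at the start to 0 at the end of training.
    cos_factor = (1 + math.cos(math.pi * progress)) / 2
    # Blend between the initial multiplier 1 and the final multiplier lrf.
    return lr * (cos_factor * (1 - lrf) + lrf)


for e in (0, 50, 99):
    print(e, f"{cosine_lr(e):.6f}")
```

With the defaults above (lr=0.01, lrf=0.001) the rate ends at 1e-5. PyTorch's `torch.optim.lr_scheduler.LambdaLR` expects `lr_lambda` to return a multiplier of the base learning rate, so there you would pass `lambda e: cosine_lr(e) / 0.01`.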

To compare unet and attentionUnet, simply change the --model argument (unet or attunet).

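The --img-size argument is the size of the images after preprocessing, i.e. the size fed into the network. In principle, a larger size means higher resolution and tends to give better results, but it also increases computation and memory usage; feel free to experiment.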
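To put that trade-off in numbers: with the architecture fixed, a fully convolutional network such as UNet does work roughly proportional to the number of input pixels, so compute and activation memory grow quadratically with the side length, while the parameter count stays the same. `relative_cost` is a hypothetical helper for this back-of-the-envelope estimate.

```python
def relative_cost(new_size, base_size=(224, 224)):
    """Rough factor by which per-image FLOPs and activation memory grow for a fully convolutional net."""
    # Both scale approximately with the number of input pixels (H * W) when the architecture is fixed.
    return (new_size[0] * new_size[1]) / (base_size[0] * base_size[1])


print(relative_cost((448, 448)))  # 4.0
```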

Because much of the data is CT, CT data can additionally be windowed; just set the ct option to True. The processing code is:

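        if self.window_CT:
            lower = 40 - 400 / 2      # window level - window width / 2  (level 40, width 400 -> -160 HU)
            upper = 40 + 400 / 2      # window level + window width / 2  (-> 240 HU)
            image = np.clip(image, lower, upper)   # values outside the window are saturated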
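The snippet above only clips the Hounsfield units to the window [40 - 400/2, 40 + 400/2] = [-160, 240]. A standalone version is sketched below; the final rescale to [0, 1] is an added assumption (a common next step before feeding a network), and `window_ct` is a hypothetical name.

```python
import numpy as np


def window_ct(image, level=40.0, width=400.0):
    """Clip a CT slice in Hounsfield units to [level - width/2, level + width/2] and rescale to [0, 1]."""
    lower = level - width / 2  # -160 HU for the default window
    upper = level + width / 2  # 240 HU for the default window
    clipped = np.clip(np.asarray(image, dtype=np.float32), lower, upper)
    # Linear rescale so that the window maps onto [0, 1] (assumed normalization step).
    return (clipped - lower) / (upper - lower)


hu = np.array([[-1000.0, -160.0], [40.0, 1000.0]])
print(window_ct(hu))
```

Air (-1000 HU) and dense bone (1000 HU) both saturate, while the soft-tissue range inside the window keeps its contrast.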

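Unlike the usual txt log, the training information is saved as a JSON file, so if you want to make your own plots you can read it directly.

As shown below, the code automatically evaluates both the training set and the validation set; the metrics include pixel accuracy, precision, recall, F1 score, Dice, and IoU: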
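Because the log is plain JSON, extracting a training curve takes only a few lines. This sketch assumes the nesting shown in the log excerpt below (top-level keys such as "epoch:99", then "train log:" or "val log:", then "info"); `mean_dice_curve` and the two-epoch sample log are hypothetical.

```python
import json


def mean_dice_curve(log, split="val log:"):
    """Return sorted (epoch, mean dice) pairs from a log dict keyed like "epoch:99"."""
    points = []
    for key, entry in log.items():
        if not key.startswith("epoch:"):
            continue
        epoch = int(key.split(":", 1)[1])
        points.append((epoch, entry[split]["info"]["mean dice"]))
    return sorted(points)


# Hypothetical two-epoch log with the same nesting as the excerpt in this article.
text = '''{"epoch:1": {"val log:": {"info": {"mean dice": 0.80}}},
           "epoch:0": {"val log:": {"info": {"mean dice": 0.75}}}}'''
print(mean_dice_curve(json.loads(text)))
```

For a real run, load the file first with `with open(path) as f: log = json.load(f)`; the log file name is not given in this article.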

    "epoch:99": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.9959836006164551
                ],
                "Precision": [
                    "0.9457",
                    "0.9407",
                    "0.9598",
                    "0.9606",
                    "0.9815",
                    "0.8911",
                    "0.9787",
                    "0.9688"
                ],
                "Recall": [
                    "0.9457",
                    "0.9407",
                    "0.9598",
                    "0.9606",
                    "0.9815",
                    "0.8911",
                    "0.9787",
                    "0.9688"
                ],
                "F1 score": [
                    "0.9491",
                    "0.9286",
                    "0.9586",
                    "0.9600",
                    "0.9806",
                    "0.8952",
                    "0.9781",
                    "0.9675"
                ],
                "Dice": [
                    "0.9491",
                    "0.9286",
                    "0.9586",
                    "0.9600",
                    "0.9806",
                    "0.8952",
                    "0.9781",
                    "0.9675"
                ],
                "IoU": [
                    "0.9032",
                    "0.8668",
                    "0.9204",
                    "0.9230",
                    "0.9620",
                    "0.8102",
                    "0.9572",
                    "0.9371"
                ],
                "mean precision": 0.9511153697967529,
                "mean recall": 0.9533761739730835,
                "mean f1 score": 0.9522219896316528,
                "mean dice": 0.9522219896316528,
                "mean iou": 0.9099960923194885
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9944279789924622
                ],
                "Precision": [
                    "0.9357",
                    "0.8845",
                    "0.9496",
                    "0.9482",
                    "0.9740",
                    "0.8008",
                    "0.9741",
                    "0.9519"
                ],
                "Recall": [
                    "0.9357",
                    "0.8845",
                    "0.9496",
                    "0.9482",
                    "0.9740",
                    "0.8008",
                    "0.9741",
                    "0.9519"
                ],
                "F1 score": [
                    "0.9331",
                    "0.8817",
                    "0.9505",
                    "0.9505",
                    "0.9738",
                    "0.8379",
                    "0.9736",
                    "0.9478"
                ],
                "Dice": [
                    "0.9331",
                    "0.8817",
                    "0.9505",
                    "0.9505",
                    "0.9738",
                    "0.8379",
                    "0.9736",
                    "0.9478"
                ],
                "IoU": [
                    "0.8746",
                    "0.7884",
                    "0.9056",
                    "0.9057",
                    "0.9489",
                    "0.7211",
                    "0.9486",
                    "0.9008"
                ],
                "mean precision": 0.9353601336479187,
                "mean recall": 0.9273472428321838,
                "mean f1 score": 0.9311231374740601,
                "mean dice": 0.9311231374740601,
                "mean iou": 0.8742245435714722
            }
        }
    }
}

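Note that metrics such as Precision are stored as lists because this dataset is multi-class: each entry is the value for one label, in ascending label order, with the background excluded. The list here has 8 entries, so the network has 8 + 1 = 9 output channels. For a single class, the F1 score and Dice are the same quantity, which is why those two lists are identical. In these logs the per-class Precision list is also identical to the Recall list, and it averages to the mean recall rather than the mean precision, so the per-class precision values appear to be recall values and should be read with caution.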
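The per-class lists can be reproduced from a confusion matrix. The sketch below is an assumption about how such metrics are typically computed, not the author's code: it counts TP, FP and FN per class, skips label 0 (background), and returns one value per foreground class, so a 9-channel network (`num_classes=9`) yields 8 values. Per class, Dice = 2TP / (2TP + FP + FN), which is exactly the F1 score.

```python
import numpy as np


def per_class_metrics(pred, target, num_classes):
    """Per-class precision, recall, Dice and IoU for labels 1..num_classes-1 (background 0 skipped)."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    # Confusion matrix: rows = ground-truth label, columns = predicted label.
    cm = np.bincount(target * num_classes + pred, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes)
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as class c but labelled as another class
    fn = cm.sum(axis=1) - tp  # labelled as class c but predicted as another class
    eps = 1e-7  # avoids division by zero for classes absent from both pred and target
    metrics = {
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
    }
    # Drop index 0 (background) to match the per-class lists in the log.
    return {name: values[1:] for name, values in metrics.items()}


target = [0, 1, 1, 2, 2, 2]
pred = [0, 1, 2, 2, 2, 1]
m = per_class_metrics(pred, target, num_classes=3)
print({k: np.round(v, 4).tolist() for k, v in m.items()})
```

Averaging each list over the foreground classes gives per-class-averaged means such as "mean dice"; whether the author averages exactly this way is not stated here.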

Inference needs only a few arguments:

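    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='unet', type=str, help='unet,attunet')    # model: unet or attunet

    parser.add_argument("--img-size", default=[224, 224], nargs=2, type=int, help='input image size')    # size of the preprocessed input image
    parser.add_argument("--ct", action='store_true', help='is CT?')    # pass --ct for CT data
    parser.add_argument("--pth", default='runs/weights/best.pth')    # path to the trained weights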

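Just put the data you want to run inference on in inference/img. The code runs inference automatically, saves the predicted masks in infer_gt, and saves the img + mask visualizations in the show directory.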

img: the data to run inference on.

infer_gt: the inference results (thresholded masks).

show: img overlaid with the predicted masks.
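The inference post-processing itself is not shown in this article. Assuming the network outputs one score map per class, the two outputs described above could be produced roughly like this: an argmax over classes for the mask saved to infer_gt, and an alpha blend for the overlay saved to show. `logits_to_mask` and `overlay` are hypothetical helpers, and the red color and 0.5 opacity are arbitrary choices.

```python
import numpy as np


def logits_to_mask(logits):
    """Convert (C, H, W) per-class scores into an (H, W) label map via argmax over classes."""
    return np.argmax(logits, axis=0).astype(np.uint8)


def overlay(image_rgb, mask, color=(255, 0, 0), alpha=0.5):
    """Blend a color into an (H, W, 3) uint8 image wherever the mask is non-zero."""
    out = image_rgb.astype(np.float32)
    fg = mask > 0  # boolean (H, W) foreground selector
    out[fg] = (1 - alpha) * out[fg] + alpha * np.asarray(color, dtype=np.float32)
    return out.round().astype(np.uint8)


logits = np.array([[[2.0, 0.1]], [[0.5, 3.0]]])  # 2 classes, 1 x 2 image
mask = logits_to_mask(logits)
img = np.zeros((1, 2, 3), dtype=np.uint8)
print(mask.tolist(), overlay(img, mask).tolist())
```

For multi-class masks, mapping each label to its own color instead of a single one makes the overlay easier to read.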

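Next, unet and attentionUnet are compared on a binary segmentation task and a multi-class segmentation task. For how to switch datasets, see the README file.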

3. Binary segmentation of the human spine

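Download link (CSDN Library): "UNet improvement series: experimental comparison of adding the attenUnet module for binary semantic segmentation of spine images"

The dataset:

3.1 unet network

Metrics: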

    "epoch:95": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.997604489326477
                ],
                "Precision": [
                    "0.9515"
                ],
                "Recall": [
                    "0.9515"
                ],
                "F1 score": [
                    "0.9536"
                ],
                "Dice": [
                    "0.9536"
                ],
                "IoU": [
                    "0.9114"
                ],
                "mean precision": 0.9557837843894958,
                "mean recall": 0.9514825344085693,
                "mean f1 score": 0.9536282420158386,
                "mean dice": 0.9536283016204834,
                "mean iou": 0.9113666415214539
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9972344636917114
                ],
                "Precision": [
                    "0.9388"
                ],
                "Recall": [
                    "0.9388"
                ],
                "F1 score": [
                    "0.9452"
                ],
                "Dice": [
                    "0.9452"
                ],
                "IoU": [
                    "0.8961"
                ],
                "mean precision": 0.951705813407898,
                "mean recall": 0.938778281211853,
                "mean f1 score": 0.9451978206634521,
                "mean dice": 0.9451978802680969,
                "mean iou": 0.8960902094841003
            }
        }
    }

Inference results:

3.2 attentionUnet

Metrics:

  "epoch:93": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.9977812767028809
                ],
                "Precision": [
                    "0.9543"
                ],
                "Recall": [
                    "0.9543"
                ],
                "F1 score": [
                    "0.9570"
                ],
                "Dice": [
                    "0.9570"
                ],
                "IoU": [
                    "0.9176"
                ],
                "mean precision": 0.9597753882408142,
                "mean recall": 0.9542863368988037,
                "mean f1 score": 0.9570229649543762,
                "mean dice": 0.9570229649543762,
                "mean iou": 0.9175877571105957
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9972677230834961
                ],
                "Precision": [
                    "0.9386"
                ],
                "Recall": [
                    "0.9386"
                ],
                "F1 score": [
                    "0.9458"
                ],
                "Dice": [
                    "0.9458"
                ],
                "IoU": [
                    "0.8972"
                ],
                "mean precision": 0.9531528353691101,
                "mean recall": 0.9385806322097778,
                "mean f1 score": 0.9458106160163879,
                "mean dice": 0.9458106160163879,
                "mean iou": 0.8971922993659973
            }
        }
    },

Inference results:

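4. Synapse multi-organ segmentation (8 classes)

Download link (CSDN Library): "UNet improvement series: experimental comparison of adding the attenUnet module for Synapse multi-organ semantic segmentation"

The 8 abdominal organs: aorta, gallbladder, spleen, left kidney, right kidney, liver, pancreas, and stomach.

4.1 unet network

Metrics: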

    "epoch:95": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.9959646463394165
                ],
                "Precision": [
                    "0.9462",
                    "0.9382",
                    "0.9599",
                    "0.9614",
                    "0.9815",
                    "0.8921",
                    "0.9789",
                    "0.9690"
                ],
                "Recall": [
                    "0.9462",
                    "0.9382",
                    "0.9599",
                    "0.9614",
                    "0.9815",
                    "0.8921",
                    "0.9789",
                    "0.9690"
                ],
                "F1 score": [
                    "0.9488",
                    "0.9288",
                    "0.9584",
                    "0.9597",
                    "0.9805",
                    "0.8946",
                    "0.9780",
                    "0.9676"
                ],
                "Dice": [
                    "0.9488",
                    "0.9288",
                    "0.9584",
                    "0.9597",
                    "0.9805",
                    "0.8946",
                    "0.9780",
                    "0.9676"
                ],
                "IoU": [
                    "0.9025",
                    "0.8671",
                    "0.9202",
                    "0.9226",
                    "0.9618",
                    "0.8093",
                    "0.9569",
                    "0.9373"
                ],
                "mean precision": 0.9507473111152649,
                "mean recall": 0.9533960819244385,
                "mean f1 score": 0.9520571827888489,
                "mean dice": 0.9520571827888489,
                "mean iou": 0.9097038507461548
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9943745136260986
                ],
                "Precision": [
                    "0.9304",
                    "0.8901",
                    "0.9501",
                    "0.9475",
                    "0.9732",
                    "0.8006",
                    "0.9751",
                    "0.9502"
                ],
                "Recall": [
                    "0.9304",
                    "0.8901",
                    "0.9501",
                    "0.9475",
                    "0.9732",
                    "0.8006",
                    "0.9751",
                    "0.9502"
                ],
                "F1 score": [
                    "0.9282",
                    "0.8822",
                    "0.9505",
                    "0.9501",
                    "0.9736",
                    "0.8388",
                    "0.9742",
                    "0.9464"
                ],
                "Dice": [
                    "0.9282",
                    "0.8822",
                    "0.9505",
                    "0.9501",
                    "0.9736",
                    "0.8388",
                    "0.9742",
                    "0.9464"
                ],
                "IoU": [
                    "0.8660",
                    "0.7893",
                    "0.9056",
                    "0.9049",
                    "0.9486",
                    "0.7224",
                    "0.9498",
                    "0.8982"
                ],
                "mean precision": 0.9343648552894592,
                "mean recall": 0.9271403551101685,
                "mean f1 score": 0.9305001497268677,
                "mean dice": 0.9305001497268677,
                "mean iou": 0.8730905652046204
            }
        }
    },

Inference results:

4.2 attentionUnet

Metrics:

    "epoch:89": {
        "train log:": {
            "info": {
                "pixel accuracy": [
                    0.9960875511169434
                ],
                "Precision": [
                    "0.9464",
                    "0.9375",
                    "0.9598",
                    "0.9601",
                    "0.9821",
                    "0.8991",
                    "0.9799",
                    "0.9702"
                ],
                "Recall": [
                    "0.9464",
                    "0.9375",
                    "0.9598",
                    "0.9601",
                    "0.9821",
                    "0.8991",
                    "0.9799",
                    "0.9702"
                ],
                "F1 score": [
                    "0.9490",
                    "0.9281",
                    "0.9591",
                    "0.9602",
                    "0.9810",
                    "0.8999",
                    "0.9791",
                    "0.9692"
                ],
                "Dice": [
                    "0.9490",
                    "0.9281",
                    "0.9591",
                    "0.9602",
                    "0.9810",
                    "0.8999",
                    "0.9791",
                    "0.9692"
                ],
                "IoU": [
                    "0.9029",
                    "0.8658",
                    "0.9215",
                    "0.9234",
                    "0.9627",
                    "0.8180",
                    "0.9590",
                    "0.9403"
                ],
                "mean precision": 0.9520106315612793,
                "mean recall": 0.9544010162353516,
                "mean f1 score": 0.9531928300857544,
                "mean dice": 0.9531928300857544,
                "mean iou": 0.9116977453231812
            }
        },
        "val log:": {
            "info": {
                "pixel accuracy": [
                    0.9945223927497864
                ],
                "Precision": [
                    "0.9200",
                    "0.8879",
                    "0.9491",
                    "0.9491",
                    "0.9733",
                    "0.8263",
                    "0.9763",
                    "0.9453"
                ],
                "Recall": [
                    "0.9200",
                    "0.8879",
                    "0.9491",
                    "0.9491",
                    "0.9733",
                    "0.8263",
                    "0.9763",
                    "0.9453"
                ],
                "F1 score": [
                    "0.9308",
                    "0.8870",
                    "0.9512",
                    "0.9536",
                    "0.9736",
                    "0.8529",
                    "0.9751",
                    "0.9462"
                ],
                "Dice": [
                    "0.9308",
                    "0.8870",
                    "0.9512",
                    "0.9536",
                    "0.9736",
                    "0.8529",
                    "0.9751",
                    "0.9462"
                ],
                "IoU": [
                    "0.8705",
                    "0.7970",
                    "0.9069",
                    "0.9114",
                    "0.9485",
                    "0.7436",
                    "0.9515",
                    "0.8979"
                ],
                "mean precision": 0.9394575357437134,
                "mean recall": 0.9284132122993469,
                "mean f1 score": 0.9338046908378601,
                "mean dice": 0.9338046908378601,
                "mean iou": 0.8784005641937256
            }
        }
    },
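To make the Synapse comparison concrete, the sketch below takes the per-class validation Dice values from the two logs above (unet at epoch 95, attentionUnet at epoch 89) and prints the per-class difference. It only does arithmetic on the reported numbers; class indices follow the label order of the logs, which this article does not map to organ names.

```python
# Per-class validation Dice copied from the Synapse logs above.
unet_dice = [0.9282, 0.8822, 0.9505, 0.9501, 0.9736, 0.8388, 0.9742, 0.9464]     # unet, epoch 95
attunet_dice = [0.9308, 0.8870, 0.9512, 0.9536, 0.9736, 0.8529, 0.9751, 0.9462]  # attentionUnet, epoch 89

# Positive values mean attentionUnet scored higher on that class.
deltas = [a - u for a, u in zip(attunet_dice, unet_dice)]
for label, d in enumerate(deltas, start=1):
    print(f"class {label}: {d:+.4f}")
print(f"mean val Dice: {sum(unet_dice) / 8:.4f} -> {sum(attunet_dice) / 8:.4f}")
```

On these numbers attentionUnet raises the mean validation Dice from 0.9305 to 0.9338; the largest gain is on class 6 (+0.0141), and class 8 is marginally lower (-0.0002). Both results come from single runs at different epochs, so differences this small may fall within run-to-run variation.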

Inference results:


Reposted from blog.csdn.net/qq_44886601/article/details/143301551