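1. Introduction
An attention module is a mechanism introduced into deep learning models to make the model focus on the most informative parts of its input. It mimics the way humans attend to information: attention is allocated dynamically to select the parts the model needs to focus on.

In a conventional deep learning model, every input feature is treated equally, regardless of how important it is. With an attention module, the model can selectively attend to different parts of the input features according to the needs of the task. This lets the model concentrate on the important features, which can improve its expressive power and generalization.
The core idea of an attention module is to compute the similarity between the input features and a set of learned parameters (usually a dot product or another similarity measure) to obtain a weight for each feature. These weights act as attention scores that indicate how much the model should attend to each feature. The features are then multiplied by their weights to produce a weighted feature representation that is used for subsequent processing and prediction.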
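The three steps just described (similarity scores, then weights, then weighted features) can be sketched in a few lines of NumPy. This is a generic dot-product attention with softmax normalization, not the gated block used in this post's code; the function name `dot_product_attention` and the `query` vector standing in for the learned parameters are assumptions made for illustration.

```python
import numpy as np

def dot_product_attention(features, query):
    """Weight each feature vector by its similarity to a query vector.

    features: array of shape (n, d), one row per input feature.
    query:    array of shape (d,), a stand-in for the learned parameters.
    Returns (weights, weighted_features).
    """
    scores = features @ query                  # dot-product similarity, shape (n,)
    scores = scores - scores.max()             # shift for a numerically stable softmax
    exp_scores = np.exp(scores)
    weights = exp_scores / exp_scores.sum()    # attention scores, sum to 1
    weighted = features * weights[:, None]     # scale each feature by its weight
    return weights, weighted

features = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = np.array([2.0, 0.0])
weights, weighted = dot_product_attention(features, query)
```

In this toy input the first and third rows align with the query equally well, so they receive the same weight, and the second row receives a smaller one.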
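Attention modules are widely used across deep learning tasks such as machine translation, image captioning, and question answering. Besides improving performance, they can help explain a model's decisions and so offer better interpretability. For these reasons, attention has become an important building block of deep learning models.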
With attention added, the relevant code is as follows:
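import torch.nn as nn


class Attention_block(nn.Module):
    """Attention gate: re-weights the feature map x using the gating signal g."""

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        # 1x1 conv + BN: project g from F_g channels to F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1x1 conv + BN: project x from F_l channels to F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # 1x1 conv + BN + sigmoid: collapse to a single-channel map in [0, 1]
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g and x must have the same spatial size for the addition below
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        # Broadcast the single-channel attention map over all channels of x
        return x * psi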
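Written as a formula, the forward pass above computes a single-channel gate and multiplies it into x. Here W_g, W_x, and W_psi stand for the 1x1 convolutions inside `W_g`, `W_x`, and `psi`, BN is batch normalization, and sigma is the sigmoid; the symbols alpha, x-hat, and the element-wise product sign are notation introduced here for illustration only.

```latex
\alpha = \sigma\Big(\mathrm{BN}\big(W_\psi\,\mathrm{ReLU}\big(\mathrm{BN}(W_g\, g) + \mathrm{BN}(W_x\, x)\big)\big)\Big),
\qquad
\hat{x} = x \odot \alpha
```

Because alpha has one channel and values in (0, 1), it is broadcast across the channels of x and can only attenuate a feature, never amplify it.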
2. Code Overview
The training arguments are as follows:
parser.add_argument("--model", default='unet', type=str,help='unet,attunet') # model to train
parser.add_argument("--img-size", default=[224,224],help='input image size') # size of the network input image
parser.add_argument("--ct", default=False,type=bool,help='is CT?') # set to True for CT data
parser.add_argument("--batch-size", default=8, type=int)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--lrf',default=0.001,type=float) # final learning rate = lr * lrf
parser.add_argument("--img_f", default='.jpg', type=str) # file extension of the input images
parser.add_argument("--mask_f", default='.png', type=str) # file extension of the mask images
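The comment on `--lrf` says the final learning rate is `lr * lrf`. The post does not show which schedule connects the initial and final values, so the cosine decay below is only an assumption used for illustration, and `cosine_lr_factor` is a hypothetical helper name.

```python
import math

def cosine_lr_factor(epoch, epochs, lrf):
    """Multiplier applied to the base learning rate at a given epoch.

    Decays from 1.0 at epoch 0 to lrf at epoch == epochs along a half cosine.
    """
    cosine = (1 + math.cos(math.pi * epoch / epochs)) / 2   # goes from 1.0 to 0.0
    return cosine * (1 - lrf) + lrf

lr, lrf, epochs = 0.01, 0.001, 100      # defaults from the arguments above
first_lr = lr * cosine_lr_factor(0, epochs, lrf)
final_lr = lr * cosine_lr_factor(epochs, epochs, lrf)
```

With the defaults, the learning rate starts at 0.01 and ends at 0.01 × 0.001 = 1e-5. A multiplier of this shape could be passed to `torch.optim.lr_scheduler.LambdaLR`.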
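To compare unet with attentionUnet, you only need to change the --model argument.
The img-size argument is the size of the images fed to the network after preprocessing. In principle, a larger img-size means higher resolution and better results, but it also increases computation and memory use; feel free to experiment.
Since many datasets are CT data, a windowing step is applied to CT images; just set --ct to True. The processing code is as follows: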
if self.window_CT:
    lower = 40 - 400 / 2  # window level - window width / 2
    upper = 40 + 400 / 2  # window level + window width / 2
    image = np.clip(image, lower, upper)
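As a self-contained sketch of the same windowing step: the function below clips to the window and then, as an extra step that is not part of the snippet above, rescales the result to [0, 1]. The name `window_ct`, the rescaling, and the assumption that intensities are in Hounsfield units are all illustrative.

```python
import numpy as np

def window_ct(image, level=40.0, width=400.0):
    """Clip a CT image to [level - width/2, level + width/2] and scale to [0, 1]."""
    lower = level - width / 2
    upper = level + width / 2
    clipped = np.clip(image.astype(np.float32), lower, upper)
    return (clipped - lower) / (upper - lower)

# Sample intensities, assumed to be Hounsfield units
hu = np.array([-1000.0, -160.0, 40.0, 240.0, 1500.0])
windowed = window_ct(hu)
```

With level 40 and width 400, the window is [-160, 240]: everything below it maps to 0, everything above it maps to 1, and the window level itself maps to 0.5.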
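Unlike the earlier approach of saving training information to a txt file, a JSON file is used here, so you can load it directly if you want to draw your own plots.
An example follows. The code automatically evaluates both the training set and the validation set; the metrics include precision, recall, Dice, IoU, and so on: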
"epoch:99": {
"train log:": {
"info": {
"pixel accuracy": [
0.9959836006164551
],
"Precision": [
"0.9457",
"0.9407",
"0.9598",
"0.9606",
"0.9815",
"0.8911",
"0.9787",
"0.9688"
],
"Recall": [
"0.9457",
"0.9407",
"0.9598",
"0.9606",
"0.9815",
"0.8911",
"0.9787",
"0.9688"
],
"F1 score": [
"0.9491",
"0.9286",
"0.9586",
"0.9600",
"0.9806",
"0.8952",
"0.9781",
"0.9675"
],
"Dice": [
"0.9491",
"0.9286",
"0.9586",
"0.9600",
"0.9806",
"0.8952",
"0.9781",
"0.9675"
],
"IoU": [
"0.9032",
"0.8668",
"0.9204",
"0.9230",
"0.9620",
"0.8102",
"0.9572",
"0.9371"
],
"mean precision": 0.9511153697967529,
"mean recall": 0.9533761739730835,
"mean f1 score": 0.9522219896316528,
"mean dice": 0.9522219896316528,
"mean iou": 0.9099960923194885
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9944279789924622
],
"Precision": [
"0.9357",
"0.8845",
"0.9496",
"0.9482",
"0.9740",
"0.8008",
"0.9741",
"0.9519"
],
"Recall": [
"0.9357",
"0.8845",
"0.9496",
"0.9482",
"0.9740",
"0.8008",
"0.9741",
"0.9519"
],
"F1 score": [
"0.9331",
"0.8817",
"0.9505",
"0.9505",
"0.9738",
"0.8379",
"0.9736",
"0.9478"
],
"Dice": [
"0.9331",
"0.8817",
"0.9505",
"0.9505",
"0.9738",
"0.8379",
"0.9736",
"0.9478"
],
"IoU": [
"0.8746",
"0.7884",
"0.9056",
"0.9057",
"0.9489",
"0.7211",
"0.9486",
"0.9008"
],
"mean precision": 0.9353601336479187,
"mean recall": 0.9273472428321838,
"mean f1 score": 0.9311231374740601,
"mean dice": 0.9311231374740601,
"mean iou": 0.8742245435714722
}
}
}
}
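Given the layout shown above (one "epoch:N" key per epoch, each holding a "train log:" and a "val log:" entry), a curve for plotting can be extracted roughly as follows. The path of the log file is not given in this post, so the sketch parses an in-memory string instead; `mean_metric_curve` is a hypothetical helper, and the two-epoch log contains placeholder values, not real results.

```python
import json

def mean_metric_curve(log, split="val log:", metric="mean dice"):
    """Return (epochs, values) for one metric, sorted by epoch number."""
    points = []
    for key, entry in log.items():
        epoch = int(key.split(":")[1])            # "epoch:99" -> 99
        points.append((epoch, entry[split]["info"][metric]))
    points.sort()
    epochs = [e for e, _ in points]
    values = [v for _, v in points]
    return epochs, values

# Placeholder log in the same layout as the training log (values are NOT real results)
sample = """
{
  "epoch:1": {"train log:": {"info": {"mean dice": 0.50}},
              "val log:":   {"info": {"mean dice": 0.45}}},
  "epoch:0": {"train log:": {"info": {"mean dice": 0.30}},
              "val log:":   {"info": {"mean dice": 0.25}}}
}
"""
log = json.loads(sample)
epochs, val_dice = mean_metric_curve(log)
```

The two lists can be handed to any plotting library. Note that the per-class entries such as "Dice" are stored as strings in the log, so they need `float()` before plotting, whereas the "mean ..." entries are already numbers.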
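Note that metrics such as precision are stored as lists because the data is multi-class: each value is the metric for one label, ordered by increasing label index, with the background excluded. The lists here have 8 values, so the network has 8 + 1 = 9 output channels. Also note that in the logs shown in this post, the per-class Precision list is identical to the Recall list even though mean precision and mean recall differ, so the per-class Precision entries should be read with caution.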
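To make the per-class lists concrete, here is a minimal sketch of how precision, recall, Dice, and IoU can be computed per foreground class from a confusion matrix, skipping the background (label 0). This is not the repository's evaluation code; `per_class_metrics` and the small 3-class confusion matrix are assumptions made for illustration.

```python
import numpy as np

def per_class_metrics(confusion):
    """Per-class metrics from a confusion matrix, background (label 0) excluded.

    confusion[i, j] = number of pixels with true label i predicted as label j.
    Assumes every class occurs at least once, so no denominator is zero.
    """
    confusion = np.asarray(confusion, dtype=np.float64)
    tp = np.diag(confusion)
    fp = confusion.sum(axis=0) - tp     # predicted as the class, but wrong
    fn = confusion.sum(axis=1) - tp     # belongs to the class, but missed
    metrics = {
        "precision": tp / (tp + fp),
        "recall": tp / (tp + fn),
        "dice": 2 * tp / (2 * tp + fp + fn),
        "iou": tp / (tp + fp + fn),
    }
    # Drop index 0 so that only foreground classes are reported
    return {name: values[1:] for name, values in metrics.items()}

# Toy example: background plus two foreground classes
confusion = [[90, 5, 5],
             [2, 8, 0],
             [1, 1, 8]]
metrics = per_class_metrics(confusion)
```

Dice and F1 are the same quantity, 2TP / (2TP + FP + FN), which is why the two lists coincide in the logs. Dice and IoU are also tied by Dice = 2·IoU / (1 + IoU); for example, an IoU of 0.9114 corresponds to a Dice of about 0.9536, as in the spine logs later in this post.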
Inference takes only a few arguments:
parser.add_argument("--model", default='unet', type=str,help='unet,attunet') # model to use
parser.add_argument("--img-size", default=[224,224],help='input image size') # size of the network input image
parser.add_argument("--ct", default=False,type=bool,help='is CT?') # set to True for CT data
parser.add_argument("--pth", default='runs/weights/best.pth')
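One thing to keep in mind about the `--ct` flag above: it is declared with `type=bool`, and argparse simply calls `bool()` on the raw string, so any non-empty value, including the text `False`, is parsed as `True`. The sketch below demonstrates this and shows one common alternative, `action="store_true"`; the alternative is a suggestion, not what the repository uses.

```python
import argparse

# Declaration as in the post: bool("False") is True because the string is non-empty.
as_written = argparse.ArgumentParser()
as_written.add_argument("--ct", default=False, type=bool, help="is CT?")
parsed_false_text = as_written.parse_args(["--ct", "False"]).ct
parsed_default = as_written.parse_args([]).ct

# Alternative (an assumption, not the repository's code): a presence flag.
alternative = argparse.ArgumentParser()
alternative.add_argument("--ct", action="store_true", help="is CT?")
flag_on = alternative.parse_args(["--ct"]).ct
flag_off = alternative.parse_args([]).ct
```

With the declaration as written, the reliable way to disable windowing is to omit `--ct` altogether.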
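Just put the data you want to run inference on under inference/img. The code runs inference automatically, saves the predicted masks in infer_gt, and writes a visualization of each image with its predicted mask to the show directory.
img holds the data to run inference on:
infer_gt holds the inference results (thresholded images)
show holds the overlay of img and the predicted mask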
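The overlay stored in show can be approximated by alpha-blending a colored mask onto the image. The repository's actual drawing code is not shown in this post, so `overlay_mask`, the red color, and the 0.5 opacity are assumptions.

```python
import numpy as np

def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Blend `color` into `image` wherever `mask` is non-zero.

    image: uint8 array of shape (H, W, 3); mask: array of shape (H, W).
    Returns a new uint8 array; the input image is left unchanged.
    """
    blended = image.astype(np.float32)            # astype makes a copy
    foreground = mask > 0
    tint = np.array(color, dtype=np.float32)
    blended[foreground] = (1 - alpha) * blended[foreground] + alpha * tint
    return blended.round().astype(np.uint8)

image = np.full((2, 2, 3), 100, dtype=np.uint8)   # uniform gray test image
mask = np.array([[0, 1], [0, 0]])                 # one foreground pixel
result = overlay_mask(image, mask)
```

For multi-class masks, the same idea applies with one color per label.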
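The following sections compare the two models on binary segmentation and on multi-class segmentation. For how to switch datasets, see the README file.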
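3. Binary Segmentation of the Human Spine
Download link: "Unet improvement series: experimental comparison of adding the attenUnet module for binary semantic segmentation of spine images" (resource on CSDN文库)
The dataset is shown below:
3.1 unet network
The metrics are as follows: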
"epoch:95": {
"train log:": {
"info": {
"pixel accuracy": [
0.997604489326477
],
"Precision": [
"0.9515"
],
"Recall": [
"0.9515"
],
"F1 score": [
"0.9536"
],
"Dice": [
"0.9536"
],
"IoU": [
"0.9114"
],
"mean precision": 0.9557837843894958,
"mean recall": 0.9514825344085693,
"mean f1 score": 0.9536282420158386,
"mean dice": 0.9536283016204834,
"mean iou": 0.9113666415214539
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9972344636917114
],
"Precision": [
"0.9388"
],
"Recall": [
"0.9388"
],
"F1 score": [
"0.9452"
],
"Dice": [
"0.9452"
],
"IoU": [
"0.8961"
],
"mean precision": 0.951705813407898,
"mean recall": 0.938778281211853,
"mean f1 score": 0.9451978206634521,
"mean dice": 0.9451978802680969,
"mean iou": 0.8960902094841003
}
}
}
Inference results:
3.2 attentionUnet
The metrics are as follows:
"epoch:93": {
"train log:": {
"info": {
"pixel accuracy": [
0.9977812767028809
],
"Precision": [
"0.9543"
],
"Recall": [
"0.9543"
],
"F1 score": [
"0.9570"
],
"Dice": [
"0.9570"
],
"IoU": [
"0.9176"
],
"mean precision": 0.9597753882408142,
"mean recall": 0.9542863368988037,
"mean f1 score": 0.9570229649543762,
"mean dice": 0.9570229649543762,
"mean iou": 0.9175877571105957
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9972677230834961
],
"Precision": [
"0.9386"
],
"Recall": [
"0.9386"
],
"F1 score": [
"0.9458"
],
"Dice": [
"0.9458"
],
"IoU": [
"0.8972"
],
"mean precision": 0.9531528353691101,
"mean recall": 0.9385806322097778,
"mean f1 score": 0.9458106160163879,
"mean dice": 0.9458106160163879,
"mean iou": 0.8971922993659973
}
}
},
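Inference results:
4. Synapse Multi-Organ Segmentation (8 Classes)
Download link: "Unet improvement series: experimental comparison of adding the attenUnet module for semantic segmentation of Synapse multi-organ images" (resource on CSDN文库)
8 abdominal organs (aorta, gallbladder, left kidney, right kidney, liver, pancreas, spleen, stomach)
4.1 unet network
The metrics are as follows: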
"epoch:95": {
"train log:": {
"info": {
"pixel accuracy": [
0.9959646463394165
],
"Precision": [
"0.9462",
"0.9382",
"0.9599",
"0.9614",
"0.9815",
"0.8921",
"0.9789",
"0.9690"
],
"Recall": [
"0.9462",
"0.9382",
"0.9599",
"0.9614",
"0.9815",
"0.8921",
"0.9789",
"0.9690"
],
"F1 score": [
"0.9488",
"0.9288",
"0.9584",
"0.9597",
"0.9805",
"0.8946",
"0.9780",
"0.9676"
],
"Dice": [
"0.9488",
"0.9288",
"0.9584",
"0.9597",
"0.9805",
"0.8946",
"0.9780",
"0.9676"
],
"IoU": [
"0.9025",
"0.8671",
"0.9202",
"0.9226",
"0.9618",
"0.8093",
"0.9569",
"0.9373"
],
"mean precision": 0.9507473111152649,
"mean recall": 0.9533960819244385,
"mean f1 score": 0.9520571827888489,
"mean dice": 0.9520571827888489,
"mean iou": 0.9097038507461548
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9943745136260986
],
"Precision": [
"0.9304",
"0.8901",
"0.9501",
"0.9475",
"0.9732",
"0.8006",
"0.9751",
"0.9502"
],
"Recall": [
"0.9304",
"0.8901",
"0.9501",
"0.9475",
"0.9732",
"0.8006",
"0.9751",
"0.9502"
],
"F1 score": [
"0.9282",
"0.8822",
"0.9505",
"0.9501",
"0.9736",
"0.8388",
"0.9742",
"0.9464"
],
"Dice": [
"0.9282",
"0.8822",
"0.9505",
"0.9501",
"0.9736",
"0.8388",
"0.9742",
"0.9464"
],
"IoU": [
"0.8660",
"0.7893",
"0.9056",
"0.9049",
"0.9486",
"0.7224",
"0.9498",
"0.8982"
],
"mean precision": 0.9343648552894592,
"mean recall": 0.9271403551101685,
"mean f1 score": 0.9305001497268677,
"mean dice": 0.9305001497268677,
"mean iou": 0.8730905652046204
}
}
},
Inference results:
4.2 attentionUnet
The metrics are as follows:
"epoch:89": {
"train log:": {
"info": {
"pixel accuracy": [
0.9960875511169434
],
"Precision": [
"0.9464",
"0.9375",
"0.9598",
"0.9601",
"0.9821",
"0.8991",
"0.9799",
"0.9702"
],
"Recall": [
"0.9464",
"0.9375",
"0.9598",
"0.9601",
"0.9821",
"0.8991",
"0.9799",
"0.9702"
],
"F1 score": [
"0.9490",
"0.9281",
"0.9591",
"0.9602",
"0.9810",
"0.8999",
"0.9791",
"0.9692"
],
"Dice": [
"0.9490",
"0.9281",
"0.9591",
"0.9602",
"0.9810",
"0.8999",
"0.9791",
"0.9692"
],
"IoU": [
"0.9029",
"0.8658",
"0.9215",
"0.9234",
"0.9627",
"0.8180",
"0.9590",
"0.9403"
],
"mean precision": 0.9520106315612793,
"mean recall": 0.9544010162353516,
"mean f1 score": 0.9531928300857544,
"mean dice": 0.9531928300857544,
"mean iou": 0.9116977453231812
}
},
"val log:": {
"info": {
"pixel accuracy": [
0.9945223927497864
],
"Precision": [
"0.9200",
"0.8879",
"0.9491",
"0.9491",
"0.9733",
"0.8263",
"0.9763",
"0.9453"
],
"Recall": [
"0.9200",
"0.8879",
"0.9491",
"0.9491",
"0.9733",
"0.8263",
"0.9763",
"0.9453"
],
"F1 score": [
"0.9308",
"0.8870",
"0.9512",
"0.9536",
"0.9736",
"0.8529",
"0.9751",
"0.9462"
],
"Dice": [
"0.9308",
"0.8870",
"0.9512",
"0.9536",
"0.9736",
"0.8529",
"0.9751",
"0.9462"
],
"IoU": [
"0.8705",
"0.7970",
"0.9069",
"0.9114",
"0.9485",
"0.7436",
"0.9515",
"0.8979"
],
"mean precision": 0.9394575357437134,
"mean recall": 0.9284132122993469,
"mean f1 score": 0.9338046908378601,
"mean dice": 0.9338046908378601,
"mean iou": 0.8784005641937256
}
}
},
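To put the four runs side by side, the sketch below takes the validation "mean dice" and "mean iou" values from the logs in this post (rounded to four decimals) and computes the difference between attentionUnet and unet. Each value comes from a single logged epoch of a single run, so the differences are a rough indication rather than a statistically grounded comparison; the helper name `improvement` is hypothetical.

```python
# Validation means copied from the logs above, rounded to 4 decimals
val_means = {
    "spine (binary)": {
        "unet":          {"mean dice": 0.9452, "mean iou": 0.8961},
        "attentionUnet": {"mean dice": 0.9458, "mean iou": 0.8972},
    },
    "Synapse (8 organs)": {
        "unet":          {"mean dice": 0.9305, "mean iou": 0.8731},
        "attentionUnet": {"mean dice": 0.9338, "mean iou": 0.8784},
    },
}

def improvement(runs, metric):
    """attentionUnet minus unet for one metric, rounded to 4 decimals."""
    return round(runs["attentionUnet"][metric] - runs["unet"][metric], 4)

deltas = {
    task: {metric: improvement(runs, metric) for metric in ("mean dice", "mean iou")}
    for task, runs in val_means.items()
}

for task, delta in deltas.items():
    print(task, delta)
```

In these logs, attentionUnet scores slightly higher than unet on both datasets, with a larger margin on the multi-organ task than on the binary spine task.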
Inference results: