5分钟速成半监督医学图像分割

在医学图像分析的领域,迅速准确地获取诊断信息是挽救生命的关键。然而,面对复杂的医学图像和巨量的数据,传统的全监督分割方法往往因依赖大量标注数据而显得力不从心。这种局限不仅拖慢了分析速度,也增加了医学领域的工作负担。想象一下,如果我们能够在仅仅5分钟内完成高精度的医学图像分割,那将会为医疗实践带来怎样的革命性改变?

目录

概述

核心逻辑

复现过程

写在最后


概述

        这里我将介绍一篇MICCAI 2023的一篇医学图像分割的文章 地址,这篇文章提出了一种新的解耦一致性半监督医学图像分割框架。该框架充分利用预测数据,将预测数据解耦为用于各种功能的数据,并最大限度地发挥每种功能的优势:

对于半监督的医学分割任务,传统的伪标签方法会过滤掉低置信度的像素,而一致性正则化并没有充分利用高置信度和低置信度数据的优势。因此,这两种方法都不能充分利用无标签数据。这篇文章提出了一种新的解耦一致性半监督医学图像分割框架。首先,利用动态阈值将预测数据解耦为一致部分和不一致部分。对于一致部分,使用交叉伪监督的方法进行优化。对于不一致部分,进一步将其解耦为可能靠近决策边界的不可靠数据和更有可能出现在高密度区域的引导数据。不可靠数据将朝着引导数据的方向进行优化,这种操作为方向一致性。此外,为了充分利用数据,我们将特征图纳入训练过程并计算特征一致性的损失。

核心逻辑

这篇文章的模型图如下图所示:

如该图所示,DC-Net包含一个编码器和两个一致的解码器,对于A解码器,用双线性插值进行上采样,对于B解码器使用反卷积进行上采样。对于有标签的数据,计算它们与真实值之间的损失 Lseg,对于一致部分,我们计算交叉伪监督损失 Lcps,对于不一致部分,我们计算方向一致性损失 Ldc,对于特征图,我们计算特征一致性损失 Lf。 

动态一致性阈值:FlexMatch [23] 证明了在训练的早期阶段,为了提高无标签数据的利用率并促进伪标签的多样化,γ 应该相对较小。随着训练的进行,γ 应该保持一个稳定的伪标签比例,其中 B 是批量大小,λ 是随着训练增加的权重系数,我们设定 λ = t/tmax。为了采集更多的无标签数据,我们对 pA 和 pB 进行阈值评估,并选择较小的阈值作为我们的一致性阈值。我们将 λt 初始化为 1/C,其中 C 表示类别数。

分解一致性:这篇文章将不一致部分解耦为不可靠数据和引导数据,其中不可靠数据可能出现在决策边界,而引导数据更有可能出现在高密度区域。这两部分具有相同的索引信息,不同之处在于引导数据比不可靠数据更有信心。基于平滑假设,这两部分的输出应该是一致的,并且位于高密度区域。因此,我们应该集中优化决策边界周围的像素,以使其更接近高密度区域。我们首先通过锐化这些像素的置信度,使高置信度像素更接近高密度区域。

复现过程

这是我们下载下来的源码目录,然后这里作者提供了一个ACDC数据集在10%标注数据样本条件下的预训练模型,这里我们直接用这个模型进行测试:

这是我们下载下来的源码目录,然后这里作者提供了一个ACDC数据集在10%标注数据样本条件下的预训练模型,这里我们直接用这个模型进行测试,这里有一个问题就是作者将模型放在了ACDC_7这个目录下,我们只需要将其移动到ACDC_mcnet_kd_DCNet_7_labeled这个目录下就行了,然后运行test_acdc.py文件,就可以得到想要的结果。下图是实现的结果:

除了这个指标性能,我们还可以得到预测到的3D医学图像: 

这里我们可以使用一个专门的3D图像可视化的软件ITK-SNAP,下面是一些可视化的结果:

这里在运行test_acdc.py文件之前,还需要做好数据集的准备,首先需要获取ACDC和PROMISE12数据集,下面我将展示一下核心代码,也添加了一些注释进行讲解说明:

output1_soft = F.softmax(output1, dim=1)
output2_soft = F.softmax(output2, dim=1)
output1_soft0 = F.softmax(output1 / 0.5, dim=1)
output2_soft0 = F.softmax(output2 / 0.5, dim=1)
# 这里是预测输出的锐化过程
with torch.no_grad():
    max_values1, _ = torch.max(output1_soft, dim=1)
    max_values2, _ = torch.max(output2_soft, dim=1)
    percent = (iter_num + 1) / max_iterations

    cur_threshold1 = (1 - percent) * cur_threshold + percent * max_values1.mean()
    cur_threshold2 = (1 - percent) * cur_threshold + percent * max_values2.mean()
    mean_max_values = min(max_values1.mean(), max_values2.mean())

    cur_threshold = min(cur_threshold1, cur_threshold2)
    cur_threshold = torch.clip(cur_threshold, 0.25, 0.95)

mask_high = (output1_soft > cur_threshold) & (output2_soft > cur_threshold)
mask_non_similarity = (mask_high == False)
# 这里是动态阈值部分的实现,这里阈值的初始值是0.25,也就是类别的倒数,然后这个值会快速地上升,最大值为0.95. 这里由这个阈值可以得到一致的高阈值区域和不一致区域。

new_output1_soft = torch.mul(mask_non_similarity, output1_soft)
new_output2_soft = torch.mul(mask_non_similarity, output2_soft)
high_output1 = torch.mul(mask_high, output1)
high_output2 = torch.mul(mask_high, output2)
high_output1_soft = torch.mul(mask_high, output1_soft)
high_output2_soft = torch.mul(mask_high, output2_soft)

pseudo_output1 = torch.argmax(output1_soft, dim=1)
pseudo_output2 = torch.argmax(output2_soft, dim=1)
pseudo_high_output1 = torch.argmax(high_output1_soft, dim=1)
pseudo_high_output2 = torch.argmax(high_output2_soft, dim=1)

max_output1_indices = new_output1_soft > new_output2_soft  # output1 距离近的像素的位置

max_output1_value0 = torch.mul(max_output1_indices, output1_soft0)
min_output2_value0 = torch.mul(max_output1_indices, output2_soft0)

max_output2_indices = new_output2_soft > new_output1_soft  # output2 距离远的像素的位置

max_output2_value0 = torch.mul(max_output2_indices, output2_soft0)
min_output1_value0 = torch.mul(max_output2_indices, output1_soft0)
# 上面这段代码就是利用一致性区域和非一致性区域的处理过程

loss_dc0 = 0
loss_cer = 0
loss_at_kd = criterion_att(encoder_features, decoder_features2)


loss_dc0 += mse_criterion(max_output1_value0.detach(), min_output2_value0)
loss_dc0 += mse_criterion(max_output2_value0.detach(), min_output1_value0)

loss_seg_dice += dice_loss(output1_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))
loss_seg_dice += dice_loss(output2_soft[:labeled_bs, ...], label_batch[:labeled_bs].unsqueeze(1))


if mean_max_values >= 0.95:
     loss_cer += ce_loss(output1, pseudo_output2.long().detach())
     loss_cer += ce_loss(output2, pseudo_output1.long().detach())
else:
     loss_cer += ce_loss(high_output1, pseudo_high_output2.long().detach())
     loss_cer += ce_loss(high_output2, pseudo_high_output1.long().detach())


consistency_weight = get_current_consistency_weight(iter_num // 150)
supervised_loss = loss_seg_dice
loss = supervised_loss + (1-consistency_weight) * (1000 * loss_at_kd) + consistency_weight * (1000 * loss_dc0 ) + 0.3 * loss_cer

写在最后

        在医学图像分析领域,准确和高效的分割技术对临床决策和患者护理至关重要。然而,传统的全监督学习方法通常需要大量的标注数据,这在医学图像中尤其难以获得。本文介绍的5分钟速成半监督医学图像分割方法,作为一种创新的解决方案,展示了在极短时间内实现高质量分割的可能性。这一方法不仅大大降低了对大量标注数据的依赖,而且通过智能算法的高效性和适应性,显著提升了分割精度和处理速度。

        随着半监督学习技术的不断进步,这种快速、高效的分割技术为医学图像分析打开了新的大门。它不仅能有效地辅助医生在繁忙的工作环境中迅速做出诊断决策,也为医学研究提供了更为可靠的数据支持。未来,我们可以期待这一技术在更多实际应用场景中的推广与深化,进一步提升医疗图像分析的智能化水平。5分钟速成半监督医学图像分割的诞生,标志着医学图像分析进入了一个新的时代,为医疗行业注入了新的活力和希望。

详细复现过程的项目源码、数据和预训练好的模型可从该文章下方附件获取

猜你喜欢

转载自blog.csdn.net/qq_53123067/article/details/141200783