[论文阅读] BoostMIS: Boosting Medical Image Semi-supervised Learning with Adaptive Pseudo Labeling

[论文地址] [代码] [CVPR 22]

Abstract

在本文中,我们提出了一个名为BoostMIS的新型半监督学习(SSL)框架,它结合了自适应伪标签和信息性主动注释,以释放医学图像SSL模型的潜力:(1)BoostMIS可以根据当前的学习状态,自适应地利用集群假设和未标记数据的一致性规范化。这一策略可以自适应地生成由任务模型预测转换而来的单次 "硬 "标签,以更好地进行任务模型训练。(2) 对于未选择的低置信度的未标记图像,我们引入主动学习(AL)算法,通过利用虚拟对抗扰动和模型的密度感知熵,找到有信息的样本作为标注候选。这些有信息量的候选样本随后被送入下一个训练周期,以便更好地进行SSL标签传播。值得注意的是,自适应伪标签和信息量大的主动标注形成了一个学习闭环,它们相互协作,促进了医学图像SSL。为了验证所提方法的有效性,我们收集了一个转移性硬膜外脊髓压迫(MESCC)的数据集,旨在优化MESCC的诊断和分类,以改善专家转诊和治疗。我们在MESCC和另一个公共数据集COVIDx上进行了BoostMIS的广泛实验研究。实验结果验证了我们的框架对不同医学图像数据集的有效性和通用性,与各种最先进的方法相比,有了明显的改善。

I. Introduction

本文是一个比较标准的将主动学习与半监督学习相结合的工作。即,利用主动学习不断选择半监督中的标注集,从而做到"boost"半监督的作用,流程如下:
在这里插入图片描述
核心点在于半监督学习中已标注样本的选择是会对总体性能有较大影响的,这样才需要主动学习算法来选择"更好的标注集"。总体来说,本文的框架是非常简单的,技术并不复杂,不过故事层面讲的很有意思。不同于自然图像分割任务,医学任务的图像是高度相似的。由于样本数量稀少,训练困难,因此很难找到足够高质量的伪标签以供学习。此外,一些低置信度的预测结果可能表明该样本值得标注。比方说,正是因为网络尚未学到这张图片中所包含的一些有价值的特征,所以才对其产生了错误的预测。

本文的AL+SSL流程如下:
在这里插入图片描述
整体步骤还是较多的。从上往下看:

  • 1)首先是Medical Image Task Model。随机初始化少量的初始样本,并将其用于分割模型的训练。需要注意的是,和绝大多数工作类似,训练样本进行了一些简单的数据增强以提升模型的鲁棒性,这里称之为弱增强(Weakly Augmentation)。
  • 2)接着是Consistency-based Adaptive Label Propagator。同许多半监督思路一样,将未标注样本送入模型产生预测结果,将其中高置信度的样本视为真值准备用于训练。注意在这里,本文提出了一种自适应阈值的方法,以在网络的不同阶段控制样本选择的质量;此外,还利用一个辅助网络减少伪标签噪声对模型训练的干扰,预防模型在self training的过程中崩溃。
  • 3)此外是Adversarial Unstability Selector。通过扰动,寻找位于模型决策边界的有价值样本。
  • 4)此外是Balanced Uncertainty Selector。利用密度熵,寻找高价值样本。

整体思路是,对于高置信度的样本,直接将其伪标签视为真值用于训练;对于低置信度的样本,利用主动学习算法挖掘其中的高价值样本,进行人工标注以用于训练。核心亮点为(可能是)首个AL+SSL医学分割工作。

II. Consistency-based Adaptive Label Propagator

一般来说,半监督的一个基本方法是伪标签。即,手工给定一个置信度,如果预测softmax结果中的最大类概率高于这个置信度,我们就认为这个结果和GT差不多,可以拿来直接作为标签。在本文中,作者指出,由于网络的学习能力是动态变化的(逐渐变强),因此一个固定的置信度阈值可能造成网络早期难以选择任何伪标签,或者在网络后期选择到大量的带噪声的伪标签。因此,本文提出使用一个自适应(逐渐增大)的阈值以供伪标签选择。

直接看公式,在网络训练至第 t t t步时的自适应阈值(Adaptive threShold, AS) ϵ t \epsilon_{t} ϵt定义如下:
ϵ t = { α ⋅ Min ⁡ { 1 ,  Count  ϵ t  Count  ϵ t − 1 } + β ⋅ N A 2 K ,  if  t < T max ⁡ α + β ,  otherwise  \epsilon_{t}= \begin{cases}\alpha \cdot \operatorname{Min}\left\{1, \frac{\text { Count }_{\epsilon_{t}}}{\text { Count }_{\epsilon_{t-1}}}\right\}+\frac{\beta \cdot N_{A}}{2 K}, & \text { if } t<T_{\max } \\ \alpha+\beta, & \text { otherwise }\end{cases} ϵt={ αMin{ 1, Count ϵt1 Count ϵt}+2KβNA,α+β, if t<Tmax otherwise  首先看这个otherwise。当 t ≥ T max ⁡ t \geq T_{\max } tTmax时,阈值锁定为固定的 α + β \alpha+\beta α+β。这个的意思是当网络训练到一定程度时,其表征学习已经较为稳定不会发生很大变化了,此时直接使用传统的手工阈值即可。需要注意的是,这里的"一定程度" T max ⁡ T_{\max} Tmax,以及 α \alpha α β \beta β都是手工指定的超参数,和固定阈值玩法相同。

而当网络表征波动较大( t < T max ⁡ t<T_{\max} t<Tmax)时,此时阈值就是自适应的了。而这个公式里还涉及到一个 C o u n t Count Count函数,因此我们首先观察其定义: Count ⁡ ϵ t = ∑ i = 1 N u 1 ( P m ( p i ∣ A w ( u i ) ) > α + β ) \operatorname{Count}_{\epsilon_{t}}=\sum_{i=1}^{N_{u}} \mathbb{1}\left(P_{m}\left(\mathbf{p}_{i} \mid A_{w}\left(\mathbf{u}_{i}\right)\right)>\alpha+\beta\right) Countϵt=i=1Nu1(Pm(piAw(ui))>α+β) 其中, N u N_{u} Nu为所有未标注样本(伪标签)的数量, p i \mathbf{p}_{i} pi为第 i i i个未标注样本所预测得到的伪标签, A w ( u i ) A_{w}(\mathbf{u}_{i}) Aw(ui)表示经过一个弱数据增强的未标注样本, P m ( p i ∣ A w ( u i ) ) P_{m}(\mathbf{p}_{i} | A_{w}(\mathbf{u}_{i})) Pm(piAw(ui))表示该未标注样本所预测得到伪标签的置信度。至此,可以发现这个 C o u n t Count Count函数记录的是当前学习阶段下能够满足(较高的)手工阈值的伪标签的数量。

再回到上面那个公式。公式左边为: α ⋅ Min ⁡ { 1 ,  Count  ϵ t  Count  ϵ t − 1 } \alpha \cdot \operatorname{Min}\left\{1, \frac{\text { Count }_{\epsilon_{t}}}{\text { Count }_{\epsilon_{t-1}}}\right\} αMin{ 1, Count ϵt1 Count ϵt}也就是说,如果  Count  ϵ t ≥  Count  ϵ t − 1 \text { Count }_{\epsilon_{t}} \geq \text { Count }_{\epsilon_{t-1}}  Count ϵt Count ϵt1,意味着网络在进一步的学习过程中,能够产生更多满足手工阈值的"高质量伪标签",此时 α \alpha α就乘以了一个大于1的系数,提高选择门槛,使得选择的标签质量更高,舍弃到一些"刚好满足阈值"的"相对低质量"样本;反之,如果  Count  ϵ t <  Count  ϵ t − 1 \text { Count }_{\epsilon_{t}} < \text { Count }_{\epsilon_{t-1}}  Count ϵt< Count ϵt1 α \alpha α不变,维护基本的选择门槛。

公式右边为: β ⋅ N A 2 K \frac{\beta \cdot N_{A}}{2 K} 2KβNA 也就是 β \beta β乘了一个系数 N A 2 K \frac{N_{A}}{2K} 2KNA K K K是一个人工超参,没啥说的, N A N_A NA则是已标注样本数。显然, N A N_A NA是会逐渐增大的,也就是 β \beta β的影响会逐渐变大。不过需要特别注意的一点是,由于主动学习的约束, N A 2 K \frac{N_{A}}{2K} 2KNA恒小于1,因此 β \beta β所乘的这个系数是逐渐从0上升至1的。

至于Consistency Regularization,由于伪标签训练是一个self training的过程,可能出现模型崩溃,因此引入了一个辅助网络,该网络以强数据增强(Strong Augmentation)的样本为输入,由伪标签进行监督。由于此时输入发生的了较为剧烈的变化,因此可以迫使网络学习图像的一些深层次的具有区分性的特征,减少对伪标签中噪声的拟合。这一思路与FixMatch非常接近,因此本文在此处基本是一笔带过。

III. Adversarial Unstability Selector

有价值的样本可以分为两种,unstable与uncertain,本节介绍unstable。大致思想为,对于未标注样本的表征,我们人为地给其加入一定的噪声。将原表征与噪声处理后的表征送入输出层以得到输出结果。如果这两者的softmax结果差异较大(用KL散度衡量),说明该样本不稳定,价值较高。

IV. Balanced Uncertainty Selector

本文对entropy进行了简单的改进,以估算不确定性。简单使用entropy会容易引入离群点、异常点、重复点,从而导致性能不佳。为此本文搞出了个density-aware entropy: Ent ⁡ ( u i u ; θ S ) = Ent ⁡ ′ ( u i u ; θ S ) ( 1 M ∑ j = 1 M Sim ⁡ ( u i u , u j u ) ) \operatorname{Ent}\left(\mathbf{u}_{i}^{u} ; \theta_{S}\right)=\operatorname{Ent}^{\prime}\left(\mathbf{u}_{i}^{u} ; \theta_{S}\right)\left(\frac{1}{M} \sum_{j=1}^{M} \operatorname{Sim}\left(\mathbf{u}_{i}^{u}, \mathbf{u}_{j}^{u}\right)\right) Ent(uiu;θS)=Ent(uiu;θS)(M1j=1MSim(uiu,uju)) Ent ⁡ ′ ( u i u ; θ S ) \operatorname{Ent}^{\prime}\left(\mathbf{u}_{i}^{u} ; \theta_{S}\right) Ent(uiu;θS)为原始的熵,在其基础上乘了一个系数 1 M ∑ j = 1 M Sim ⁡ ( u i u , u j u ) \frac{1}{M} \sum_{j=1}^{M} \operatorname{Sim}(\mathbf{u}_{i}^{u}, \mathbf{u}_{j}^{u}) M1j=1MSim(uiu,uju)。该系数的含义为,对于该样本 u i u \mathbf{u}_{i}^{u} uiu,计算其与其他点的相似性。如果相似性高,则说明该样本是比较具有代表性的(而非异常点离群点),更应该被选择。

猜你喜欢

转载自blog.csdn.net/qq_40714949/article/details/124999241