半监督学习——FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

论文:https://arxiv.org/abs/2001.07685

代码:https://github.com/google-research/fixmatch

1. 论文题目与摘要

                               FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

       摘要:半监督学习有效的利用没有标注的数据,从而提高模型的精度。这篇论文,我们将有效的结合两种常见的半监督学习方法:一致性正规化技术和伪标签技术。我们的算法叫做FixMatch,首先把没有标签的图片进行轻微的数据增强,用模型对怎强后的图片进行预测,从而生成为标签。对于每张没有标签的图片,当模型的预测得分高于一定的阈值时,伪标签才起作用。模型预测伪标签的同时,将同样的图片进行强烈的数据增强送入网络,计算损失。虽然方法看起来简单,但是FixMatch在从多的半监督学习方法中达到了最好的效果。仅用了250张标注数据,在CIFAR-10数据集上达到了94.93%的准确率;仅用了40张标注数据,在CIFAR-10数据集上达到了88.61%的准确率(每个类别只取了4张标注数据);因为作者做了很多消融实验,说明不同因素对半监督学习效果的影响,最终FixMatch这种半监督学习方法获得成功。我们的代码已经开源:https://github.com/google-research/fixmatch.

 2. 算法主要流程       

Caption

             首先,图片进行轻微的数据增强,然后输入网络进行预测,生成独热编码的为标签。然后,把同样的图片进行强烈的数据增强,得到预测特征。如果轻微数据增强的预测得分大于一定的阈值,那么生成的为标签就和强烈数据增强的特征计算交叉熵损失。整个过程如上图所示:

3. 实现细节

            从整体来看,FixMatch算法是两种半监督学习算法的简单结合,即一致性正则化技术和伪标签技术。

            FixMatch的损失函数有两部分组成:有标签的图片用有监督的损失Ls,没有标签的图片用无监督的损失Lu, 两个损失都是标准的交叉熵损失。

            首先,看看有监督的损失函数,标准的交叉熵损失函数:

           再看看对于没有标签图片的处理:首先得到伪标签,如果伪标签的得分大于一定的阈值(τ,论文中的阈值取0.95),那么,就用该伪标签和强烈数据增强获得的特征计算交叉熵损失:           

           最后,FixMatch的损失函数为:Ls + λ * Lu, 其中λ是一个超参数,用来平衡两个损失函数的,论文中λ=1。

           论文中超参数的设置如下:

            其中:μ为无标签图片和有标签图片的比例。

            模型训练的伪代码如下图所示:

           当然,作者还做了很多消融实验,例如:调节学习率、选择优化器等等。作者的工作量还是挺大的,但创新点就那么多。

4. 总结

           作者提出了一种简单的半监督学习算法:FixMatch,该半监督学习算法在多个数据集上达到了最先进的的结果。FixMatch搭建了low-label semi-supervised learning 和 few-shot learning的联系,甚至聚类算法。作者每一类仅用一张有标签图片,就获得了很高的准确率。对于有标签和无标签的图片,由于Fixmatch用标准的交叉熵损失函数,所以Fixmatch训练工程仅用几行代码就可以完成。

以上是博主对论文的理解,如需讨论,请留言!

猜你喜欢

转载自blog.csdn.net/Guo_Python/article/details/107867272