python工具方法 33 基于lossFusion类轻松实现多个loss的集成

在进行深度学习实践中,经常会使用到混合loss,无论是图像分类、还是语义分割。在图像分类中,经常将交叉熵与focal loss或者ghm loss进行混合;在语义分割中经常将交叉熵、dice loss、focal loss、iou loss、MS-SSIM loss等进行混合。使用混合loss通常能获得更为稳定的性能,且基于混合loss对评价指标(dice、iou)的优化,甚至能进一步提升性能。此外,如focal loss、ghm loss解决了样本不平衡的问题,也能提升性能。

1、现有问题

如果将模型训练代码中的单一loss更改为多个loss加权混合,对代码改动量较大。且,在语义分割模型中有着深度监督的概念(模型输出多个尺度的预测结果,如hrnet、unet3+等网络),这种多个out的输出,不能直接使用ce loss,dice loss等,需要对多个输出进行遍历,这使得计算loss的代码更为复杂了。
为了解决这一问题,博主提出了lossFusion类的使用(支持paddle、pytorch、tensorflow等框架)。通过使用该类,不管是单个loss,还是多个加权loss。不管模型输出的结果是一个还是多个尺度,都只需要相同的代码即可计算loss。

2、解决方案

具体实现代码如下所示,lossFusion类支持传入两个参数,分为为loss_list和loss_weights。loss_list为list对象,支持传入1个或多个loss;loss_weights也为list对象,支持传入每个loss对应的权重。博主通过改写__call__函数实现了直接通过对象名调用计算loss&

猜你喜欢

转载自blog.csdn.net/a486259/article/details/125956395
33