多任务学习(MTL) --- 知识小结+实现

  • What

    • 通常一个模型训练时有多个目标函数loss同时训练就可以叫多任务学习,预测时输出多个结果的模型就是多任务模型

在这里插入图片描述

  • Why

    • 工业界实际应用时维护单个模型比同时维护k个模型更方便,成本更低
    • 提高泛化性能
  • How

    • 思路1:手工加权平均

      • 基本思想:对于多任务的loss,最简单的方式是直接将loss函数对于每个任务的loss进行加权
      • 这种方式手工设置权重,模型性能对权重的选择非常敏感,而且loss的权重作为超参数进行调参很不方便,更好的加权方式应该是自适应动态调整的

在这里插入图片描述

  • 思路2:动态加权平均

    • 基本思想:不同任务难易程度不同,学习速度不同,对于这点可以针对不同任务设置不同学习率,但是更好的思路是动态调整让各个任务以相近的速度学习,这就是DWA(Dynamic Weight Averaging — 动态加权平均)算法的核心思想

    • loss下降快的任务,则权重会变小;反之权重会变大

在这里插入图片描述

  • 思路3:动态任务优先级

    • 基本思想:难学的任务给予更高的权重

    • KPI高的任务,学习起来比较简单,则权重会变小;反之,难学的任务权重会变大

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NkYLHbcj-1619514704411)(/home/goodix/share/2021-04-27 16-55-57屏幕截图.png)]

  • 思路4:不确定性加权方法

    • 基本思想:难学的任务给予更小的权重使得整体的多任务模型的训练更加顺畅和有效(和思路3相反。。)

    • 前提概念:认知不确定性和偶然不确定性

      • 认知不确定性(epistemic):指的是由于缺少数据导致的认知偏差。当数据很少的时候,训练数据提供的样本分布很难代表数据全局的分布,导致模型训练学偏。这种不确定性可以通过增加数据来改善。
      • 偶然不确定性(aleatoric):指的是由于数据本身,或者任务本身带来的认知偏差。偶然不确定性有个特点,其不会随着数据量增加而改善结果,数据即使增加,偏差仍然存在。
      • 偶然不确定性可以分为两种情况:
        1. 数据依赖型或异方差。在进行数据标注的时候的误标记、错标记等,这些错误的数据也会造成模型预测偏差;
        2. 任务依赖型或同方差。这个指的是,同一份数据,对于不同的任务可能会导致不同的偏差
    • 这种思路希望基于偶然不确定性(aleatoric)中的同方差不确定性来进行建模,以两个任务为例,最终推导后的loss函数:
      在这里插入图片描述

      • 其中,sigma1和sigma2是两个任务中,各自存在的不确定性
      • sigma越大,任务的不确定性越大,则任务的权重越小,即噪声大且难学的任务权重会变小,简单的任务权重变大
    • 原论文《Multi-task learning using uncertainty to weigh losses for scene geometry and semantics》

      • github上用pytorch的实现:https://github.com/Mikoto10032/AutomaticWeightedLoss
      import torch
      import torch.nn as nn
      
      class AutomaticWeightedLoss(nn.Module):
          """automatically weighted multi-task loss
          Params:
              num: int,the number of loss
              x: multi-task loss
          Examples:
              loss1=1
              loss2=2
              awl = AutomaticWeightedLoss(2)
              loss_sum = awl(loss1, loss2)
          """
          def __init__(self, num=2):
              super(AutomaticWeightedLoss, self).__init__()
              params = torch.ones(num, requires_grad=True)
              self.params = torch.nn.Parameter(params)
      
          def forward(self, *x):
              loss_sum = 0
              for i, loss in enumerate(x):
                  loss_sum += 0.5 / (self.params[i] ** 2) * loss + torch.log(1 + self.params[i] ** 2)
              return loss_sum
      
      if __name__ == '__main__':
          awl = AutomaticWeightedLoss(2)
          print(awl.parameters())
      
      • 应用示例:
      from torch import optim
      from AutomaticWeightedLoss import AutomaticWeightedLoss
      
      model = Model()
      
      awl = AutomaticWeightedLoss(2)	# we have 2 losses
      loss_1 = ...
      loss_2 = ...
      
      # learnable parameters
      optimizer = optim.Adam([
                      {'params': model.parameters()},
                      {'params': awl.parameters(), 'weight_decay': 0}
                  ])
      
      for i in range(epoch):
          for data, label1, label2 in data_loader:
              # forward
              pred1, pred2 = Model(data)	
              # calculate losses
              loss1 = loss_1(pred1, label1)
              loss2 = loss_2(pred2, label2)
              # weigh losses
              loss_sum = awl(loss1, loss2)
              # backward
              optimizer.zero_grad()
              loss_sum.backward()
              optimizer.step()
      

猜你喜欢

转载自blog.csdn.net/hechao3225/article/details/116205265