keras多任务多loss回传的思考

如果有一个多任务多loss的网络,那么在训练时,loss是如何工作的呢?
比如下面:

model = Model(inputs = input, outputs = [y1, y2])
l1 = 0.5
l2 = 0.3
model.compile(loss = [loss1, loss2], loss_weights=[l1, l2], ...)

其实我们最终得到的loss为

final_loss = l1 * loss1 + l2 * loss2

我们最终的优化效果是最小化final_loss。
问题来了,在训练过程中,是否loss2只更新得到y2的网络通路,还是loss2会更新所有的网络层呢?
此问题的关键在梯度回传上,即反向传播算法。

在这里插入图片描述
对于x1参数的更新:
f i n a l l o s s x 1 = l 1 l o s s 1 X 1 + l 2 l o s s 2 X 2 \frac{\partial finalloss}{\partial x1}=\frac{l1 *\partial loss1}{\partial X1} + \frac{l2 *\partial loss2}{\partial X2}
对于x2参数的更新:
f i n a l l o s s x 2 = l 2 l o s s 2 X 2 \frac{\partial finalloss}{\partial x2}= \frac{l2 *\partial loss2}{\partial X2}
对于x2参数的更新:
f i n a l l o s s x 1 = l 1 l o s s 1 X 1 \frac{\partial finalloss}{\partial x1}= \frac{l1 *\partial loss1}{\partial X1}
所以loss1只对x1和x2有影响,而loss2只对x1和x3有影响。

参考:https://stackoverflow.com/questions/49404309/how-does-keras-handle-multiple-losses

猜你喜欢

转载自blog.csdn.net/m0_37477175/article/details/85163362