如何手写softmax函数防止数值溢出?

当我手写cross-entropy的时候,发现有时候竟然会出现error?整个数学计算过程没问题,主要问题就在于上溢出和下溢出,即当遇到极大或是极小的logits的时候,如果直接用公式按照exp的方式去进行softmax的话就会出现数值溢出的情况。为了解决这个问题,首先需要做的就是减去最大值,即:

logits = logits - torch.max(logits, 1)[0][:, None]

原理可以看这个链接:
https://zhuanlan.zhihu.com/p/29376573

但是我减去最大值之后还是会出现溢出,这个时候经过检查发现softmax后还是出现了0的情况,那再经过log函数之后就会变成负无穷,此时不要用手写的:

torch.log(F.softmax(logits, dim=-1))

而是直接使用torch自带的log_softmax,其做了一定的容错控制:

F.log_softmax(logits, dim=-1)

或者在使用log的时候加一个很小的数,防止出现0的情况。

猜你喜欢

转载自blog.csdn.net/weixin_42988382/article/details/123284103