关于网络结构输出层加了softmax后,loss训练不下降的问题

小白第一次写,很多地方会显得比较生疏……

———————————————————————————————————————————

1.13 更新:

原因是nn.crossentropy()内部封装了logsoftmax函数,再用一次softmax的话会导致概率过早进入不能被训练的饱和状态(即假概率逼近0,真概率逼近1)。

所以解决措施是:

1. 去掉网络结构里的softmax层

2. 重新实现cross-entropy函数

尝试后,这两种方法都可以成功训练数据。

附: torch在nn.crossentropy()内封装logsoftmax的原因

主要是为了防止数值溢出,通过logsoftmax的公式推导可以得出一个避免进行log(0) / exp(∞)的运算
​​​​​​​详见:3.7. softmax回归的简洁实现 — 动手学深度学习 2.0.0 documentation (d2l.ai)

———————————————————————————————————————————

首先记录的是为什么多分类问题中需要用softmax+crossentropyloss交叉熵函数的配合。

交叉熵本质上是衡量输入与输出之间的区别,其值越小,则区别越小,毕竟从名字上看,熵就是衡量一种混乱程度嘛。

在多分类问题中,我们通常会将输入的标签编码为one-hot形式,形如[0,0,1,0,0],则此数据属于数据集中的第三类。这是交叉熵函数的输入之一。

在我们训练的神经网络最后一层,我们通常会使用一个全连接层,其输出维度就是我们类的总个数。假定输出为[-10,3,20,9,-1],那么此时我们需要一种方法去衡量这个输出与输入one-hot标签的区别,作为loss提供给network进行反向传播。当然,这种方法(函数)需要满足端到端的可导性质。

那么这个时候,softmax函数看上去是不错的,因其可以将输出数据转换为概率形式,并且让大的数据更大,小的数据更小,以此强化输出的类别特征。

这个时候我们又发现,交叉熵函数,其形式刚好又是按照定义信息量的概率类型。那两者这一配合,岂不完美。

但此处我翻看pytorch官方文档时,发现官方文档又告诉我们,使用pytorch集成的nn.crossentropyloss()函数时,最好不要用Softmax处理为概率类型的输出。

我又看仔细了其函数定义,实际上是因为这里面的nn.crossentropyloss()函数已经集成了一种logSoftmax()的方法,所以相当于已经给你转成了概率格式,不需要你自己再转一次了(我是这么理解的)官方文档说的理由我并不是很看得懂,希望有大佬能解惑。

 另外,官方文档上的事例,输入也是不需要转换one-hot格式,直接用类别标签就行了。

那么,至此我看懂了交叉熵函数的公式,原理以及其在pytorch下的用法,接下来是时候去解决标题中的问题了。

标题所述问题解决

我所遇到的问题是这样。

首先我处理的是一个17分类的问题,数据有190维度,10515样本量。我制作的dataloader类的数据标签为单标签格式,也并没有处理为one-hot编码。

我的网络结构是一个190*1000*17的全连接神经网络,激活函数选择ReLu,优化方法SGD,损失函数nn.crossentropyloss()。但是如果我在最后一层加了一个Softmax激活函数的话,训练精度就完全不能上升。如图所示:

 而去掉网络结构中的全连接层后,训练集准确率可以在几轮内就训练到75左右。

这让我非常疑惑。因为官方文档中所述的不用类别概率作为交叉熵函数输入的理由也仅仅是不推荐,我并不认为这个原因会导致如此严重的后果,甚至完全训练不了。

于是我想到查看hidden层的参数梯度,发现过了10个epoch左右,未加softmax的梯度值依然处在一个e-3的水平,而加了softmax则直接低到了e-7。。。

究其原因,softmax给整个网络在链式求导中多了一层梯度,而由于ak经过softmax概率归一化,其值大多很低,大多在e-2次方数量级。这样相当于直接给梯度增加了接近e-3~-4次方的梯度,所以导致了梯度消失的现象。

至此也算是解决了个问题吧!神经网络玄学太多,我只想抓住那么一丢丢可以解释的现象理解一下,帮助我建立一写insight吧hhh


 

参考:

(9条消息) Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss_梦坠凡尘的博客-CSDN博客

(9条消息) 超详细的softmax的反向传播梯度计算推导_深肚学习的博客-CSDN博客_softmax反向传播公式CrossEntropyLoss — PyTorch 1.13 documentation

猜你喜欢

转载自blog.csdn.net/Blossomers/article/details/124080960