torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 'target'的一种可能原因

版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/84189741

可能是在使用交叉熵损失函数的时候,target需要是整数,才能转化成索引值,进而进行one-hot编码。

输出一下target的张量,可以看到每个值都后面有一个点.比如5.这样,应该表示的就是浮点类型的值。

这个时候需要target=target.long()执行一下类型转换。

猜你喜欢

转载自blog.csdn.net/york1996/article/details/84189741