RuntimeError: 1D target tensor expected, multi-target not supported

输出的标签是:[1., 0., 0., 0., 0.],

提取最大值所在的 index

 labels_ = torch.max(labels, 1)[1]    

### 返回最大值的索引

发布了234 篇原创文章 · 获赞 61 · 访问量 12万+

猜你喜欢

转载自blog.csdn.net/weixin_42528089/article/details/104836539