Fix for: pred = torch.max(a,1,keepdim=True)[1]

pred = torch.max(a,1,keepdim=True)[1]
TypeError: torch.max received an invalid combination of arguments - got (torch.LongTensor, int, keepdim=bool), but expected one of:
 * (torch.LongTensor source)
 * (torch.LongTensor source, torch.LongTensor other)
 * (torch.LongTensor source, int dim)

      didn't match because some of the keywords were incorrect: keepdim
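
For reference, here is a minimal sketch of the three forms the message does accept. It is written against a current PyTorch release, so it builds its input with torch.tensor, which older builds lack. The tensor values are made up for illustration.

```python
import torch

# A small integer (int64, i.e. LongTensor) input, matching the type in the error.
a = torch.tensor([[1, 5, 3],
                  [7, 2, 9]])

# Overload 1: torch.max(source) -> the single largest element.
overall = torch.max(a)

# Overload 2: torch.max(source, other) -> element-wise maximum of two tensors.
clipped = torch.max(a, torch.full_like(a, 4))

# Overload 3: torch.max(source, dim) -> (values, indices) along dimension `dim`.
values, indices = torch.max(a, 1)

print(overall.item())    # 9
print(clipped.tolist())  # [[4, 5, 4], [7, 4, 9]]
print(values.tolist())   # [5, 9]
print(indices.tolist())  # [1, 2]
```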


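Fix for the error above

The overloads listed in the message are the only ones the installed (older) PyTorch accepts, and none of them takes a keepdim keyword. Drop the keyword and use the plain (tensor, dim) form instead.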

Change the original code

 pred = output.data.max(1, keepdim=True)[1]

to

 pred = torch.max(output.data,1)[1] 
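
One caveat, assuming a current PyTorch release where keepdim is accepted: the two lines are not identical in shape. max(1, keepdim=True)[1] keeps the reduced dimension and returns an (N, 1) index tensor, while torch.max(x, 1)[1] returns shape (N,). The sketch below checks this and wraps the fix in a hypothetical helper, argmax_keepdim (not a PyTorch function), that tries keepdim first and falls back to the plain (tensor, dim) overload, reshaping with view(-1, 1) so callers always get (N, 1). It assumes a 2-D score tensor reduced over dimension 1; the scores are invented.

```python
import torch

def argmax_keepdim(scores):
    """Hypothetical helper: column index of each row's max, shaped (N, 1).

    Assumes a 2-D `scores` tensor. Tries the keepdim keyword first; on builds
    whose torch.max rejects it (the TypeError above), falls back to the
    (tensor, dim) overload and restores the column shape with view(-1, 1).
    """
    try:
        return torch.max(scores, 1, keepdim=True)[1]
    except TypeError:
        return torch.max(scores, 1)[1].view(-1, 1)

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])

# On current PyTorch, the two forms from the post differ only in shape.
with_keepdim = output.max(1, keepdim=True)[1]   # shape (2, 1)
without_keepdim = torch.max(output, 1)[1]       # shape (2,)

print(with_keepdim.shape, without_keepdim.shape)
print(argmax_keepdim(output).tolist())  # [[1], [0]]
```

If nothing downstream depends on the extra dimension, the plain torch.max(output.data,1)[1] from the fix is enough; the explicit view(-1, 1) only matters when later code compares against an (N, 1) tensor.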


Reference

torch.max


Reposted from blog.csdn.net/zhuoyuezai/article/details/80395862