使用CrossEntropy是常见到如上错误:
例如:
criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target.long())
其中,logit: torch.Size([4, 31, 256, 256]); target: [4, 256, 256, 1]
就会出现1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size的错误
更多信息可参考:https://www.computationalimaging.cn/2020/01/1only-batches-of-spatial-targets.html
错误原因:
Pytorch的CrossEntropy官方文档截图如下:
显然,其中N=4, C=31, d1=d2=256, 故target中的1是多余的,是报错的原因。
解决方案:
loss = criterion(logit, torch.squeeze(target).long())
将target变为[4, 256, 256]即可。