[Linux服务器 错误] RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM

这是因为数据类型不一致造成的。我出现的原因是因为Numpy默认产生的数据是float64, 而之前的数据都是float32。所以改变构造数据的代码如下
A = np.ones([1,5], dtype = float32)

另一个错误“RuntimeError: expected type torch.cuda.DoubleTensor but got torch.cuda.FloatTensor”是由于没有把tensor变成cuda并且数据类型不一致
添加 .cuda() 并由上文一样检查数据类型

猜你喜欢

转载自blog.csdn.net/qq_45347185/article/details/108965424