PyTorch:expected scalar type Float but found Double

这个问题很明显就是网络内的参数类型不同意;
修改:
在前面添加:

torch.set_default_tensor_type(torch.DoubleTensor)

或者,在运行网络前添加:

net = net.double()

猜你喜欢

转载自blog.csdn.net/sazass/article/details/109725458