RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be same

在这里插入图片描述
简而言之,就是输入类型是对应cpu的torch.FloatTensor,而模型网络的超参数却是用的对应gpu的torch.cuda.FloatTensor
一般是在本地改代码的时候,忘记将forward(step)的一些传递的参数to(device)导致的,本人就是如此,哈哈。

解决方法如下:

以下是针对每个batch解压数据的时候,对其每类数据to(device),一般在for batch in self.train_data(或者train_dataloader这个循环中)

if self.args.device != 'cpu':
    # batch = tuple(t.to(self.args.device) for t in batch)
    batch = (tup.to(self.args.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)

反之同理

如果RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be same
就是将模型/网络采用to(device)即可。
device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)

猜你喜欢

转载自blog.csdn.net/weixin_42455006/article/details/125268319