RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.DoubleTensor) should be the

Where the error occurs:

train_label_batch = torch.from_numpy(train_label_batch)

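Solution: convert the tensor to FloatTensor. torch.from_numpy keeps the dtype of the NumPy array, so a float64 array becomes a DoubleTensor. Adding one line, Tensor.type(torch.FloatTensor), does the conversion, as shown below: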

train_label_batch = torch.from_numpy(train_label_batch)
train_label_batch = train_label_batch.type(torch.FloatTensor)  # convert to float32
train_label_batch = train_label_batch.cuda()  # move to the GPU
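
The fix above can be checked with a small, self-contained sketch. The array name train_label_batch_np and its shape are made up for illustration, and .float() is used here as the shorthand for converting to float32. The device check is an addition so the snippet also runs on a machine without a GPU; it is not part of the original fix.

```python
import numpy as np
import torch

# Hypothetical stand-in for the label batch; NumPy creates float64 by default.
train_label_batch_np = np.random.rand(4, 1)

# torch.from_numpy keeps the NumPy dtype, so this tensor is float64 (Double).
labels = torch.from_numpy(train_label_batch_np)
dtype_before = labels.dtype

# .float() converts the tensor to torch.float32.
labels = labels.float()

# Use the GPU only if one is available (assumption: CPU fallback is acceptable).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels = labels.to(device)

print(dtype_before, "->", labels.dtype, "on", labels.device)
```

Printing the dtype before and after the conversion is a quick way to confirm which side of the mismatch a tensor is on.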


Reposted from blog.csdn.net/jizhidexiaoming/article/details/82502280