pytorch 参数注册问题

在实现deepfm是进行特征编码时遇到RuntimeError: Function AddBackward0 returned an invalid gradient at index 1 - expected type torch.FloatTensor but got torch.cuda.FloatTensor问题,但模型和输入都已经to(device),经检查发现nn.ModuleList nn.ModuleDict的参数无法通过model.to(device)自动注册,要手动注册。

猜你喜欢

转载自www.cnblogs.com/yutingmoran/p/11884422.html