一、问题描述
多个GPU 训练,保存时没有加module , 导致加载模型时报错。正确写法应该如下:
# save model
if num_gpu == 1:
torch.save(model.module.state_dict(), os.path.join(opt.outf, 'net.pth'))
else:
torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
二、解决方法
load 模型时,删除多余的module. 可以打印下面代码中的pth , 查看问题症结
具体代码如下:
print("load pre_training weight. ")
pth = torch.load(os.path.join(opt.outf, 'net.pth'))
new_state_dict = OrderedDict()
for k, v in pth.items():
name = k[7:] # remove 'module'
new_state_dict[name]=v
net.load_state_dict(new_state_dict)