In deep learning, when loading the trained weights to run model prediction, I ran into the following error:
RuntimeError: Error(s) in loading state_dict for DeepLabV3:
Missing key(s) in state_dict: "classifier.aspp.convs.1.0.weight", "classifier.aspp.convs.2.0.weight", "classifier.aspp.convs.3.0.weight", "classifier.classifier.0.weight".
Unexpected key(s) in state_dict: "classifier.aspp.convs.1.0.body.0.weight", "classifier.aspp.convs.1.0.body.1.weight", "classifier.aspp.convs.2.0.body.0.weight", "classifier.aspp.convs.2.0.body.1.weight", "classifier.aspp.convs.3.0.body.0.weight", "classifier.aspp.convs.3.0.body.1.weight", "classifier.classifier.0.body.0.weight", "classifier.classifier.0.body.1.weight".
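A minimal sketch of the comparison behind this message, in plain Python. Missing keys are the model's keys that are absent from the checkpoint, and unexpected keys are the checkpoint's keys that are absent from the model. The helper names diff_state_dict_keys and group_by_stripped_body are hypothetical. In a real script the two key lists would come from model.state_dict().keys() and checkpoint["model_state"].keys(). Stripping the .body.N level from the unexpected keys shows that each missing key corresponds to exactly two of them.

```python
import re


def diff_state_dict_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) as set differences, like a strict load_state_dict."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)     # model expects these, checkpoint lacks them
    unexpected = sorted(ckpt_set - model_set)  # checkpoint has these, model does not define them
    return missing, unexpected


def group_by_stripped_body(unexpected):
    """Map each unexpected key, with any '.body.<n>' level removed, to the original keys."""
    groups = {}
    for key in unexpected:
        base = re.sub(r"\.body\.\d+(?=\.)", "", key)
        groups.setdefault(base, []).append(key)
    return groups


# Keys taken from the error message; "classifier.aspp.convs.0.0.weight" is an
# illustrative key assumed to exist in both the model and the checkpoint.
shared = ["classifier.aspp.convs.0.0.weight"]
prefixes = ("classifier.aspp.convs.1.0", "classifier.aspp.convs.2.0",
            "classifier.aspp.convs.3.0", "classifier.classifier.0")
model_keys = shared + [f"{p}.weight" for p in prefixes]
ckpt_keys = shared + [f"{p}.body.{i}.weight" for p in prefixes for i in (0, 1)]

missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
groups = group_by_stripped_body(unexpected)
print(sorted(groups) == missing)  # True: every missing key has a body.0/body.1 pair
```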
Here is how I fixed it.
Code that raised the error:
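import os
import torch
import torch.nn as nn

# opts, model and device are defined earlier in the script
if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
    # strict=True (the default): any missing or unexpected key raises RuntimeError
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    print("Resume model from %s" % opts.ckpt)
    del checkpoint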
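Cause of the error:
The model architecture used when the checkpoint was trained and saved differs from the model built for loading. Every missing key has a matching pair of unexpected keys with an extra body.0 / body.1 level (for example, classifier.aspp.convs.1.0.weight is missing while classifier.aspp.convs.1.0.body.0.weight and classifier.aspp.convs.1.0.body.1.weight are unexpected). This suggests that in the checkpoint these convolutions were depthwise-separable (a depthwise conv followed by a pointwise conv, stored as body.0 and body.1), while the model being loaded uses ordinary convolutions under the same names.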
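To illustrate the cause, here is a toy reproduction: the same conv layer is saved once as a depthwise + pointwise pair wrapped in a Sequential named body, and once as a plain Conv2d. The class names SeparableConv2d and Head, the channel count, and the body layout are assumptions chosen to mirror the key names in the error message, not the actual DeepLabV3 code.

```python
import torch.nn as nn


class SeparableConv2d(nn.Module):
    """Depthwise 3x3 conv followed by pointwise 1x1 conv, kept in a `body` Sequential."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),  # body.0: depthwise
            nn.Conv2d(channels, channels, 1, bias=False),                              # body.1: pointwise
        )

    def forward(self, x):
        return self.body(x)


class Head(nn.Module):
    """Toy head whose single conv is either a plain 3x3 conv or the separable variant."""

    def __init__(self, separable):
        super().__init__()
        if separable:
            self.conv = SeparableConv2d(8)
        else:
            self.conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(x)


trained = Head(separable=True)     # architecture when the checkpoint was saved
inference = Head(separable=False)  # architecture built at prediction time

print(list(trained.state_dict()))    # ['conv.body.0.weight', 'conv.body.1.weight']
print(list(inference.state_dict()))  # ['conv.weight']

try:
    inference.load_state_dict(trained.state_dict())  # strict=True is the default
except RuntimeError as err:
    # Same kind of error: conv.weight missing, conv.body.0/1.weight unexpected
    print(err)
```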
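Solution:
Load the weights with strict=False, so that load_state_dict skips the keys that do not match instead of raising an error. Note that this only silences the check: the skipped layers (here the three ASPP convolutions and the first classifier convolution) keep whatever weights the model was constructed with, normally a random initialization, so they do not use the trained weights. If the training script built the model with separable convolutions, building the prediction model the same way before loading lets every key match.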
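import os
import torch
import torch.nn as nn

# opts, model and device are defined earlier in the script
if opts.ckpt is not None and os.path.isfile(opts.ckpt):
    checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
    # strict=False: mismatched keys are skipped instead of raising RuntimeError
    model.load_state_dict(checkpoint["model_state"], strict=False)
    model = nn.DataParallel(model)
    model.to(device)
    print("Resume model from %s" % opts.ckpt)
    del checkpoint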
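A sketch of how to check what strict=False actually skipped: load_state_dict returns a named tuple with missing_keys and unexpected_keys fields. The PlainHead model and the checkpoint contents below are hypothetical toy data laid out like the error above.

```python
import torch
import torch.nn as nn


class PlainHead(nn.Module):
    """Hypothetical model built at prediction time: a plain 3x3 conv plus BatchNorm."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


model = PlainHead()
conv_before = model.conv.weight.detach().clone()  # initial (random) conv weights

# Hypothetical checkpoint: the conv was saved as depthwise + pointwise (body.0 / body.1),
# the BatchNorm under the same names the model uses (weight set to 2.0 to make it visible).
bn_state = nn.BatchNorm2d(8).state_dict()
bn_state["weight"] = torch.full((8,), 2.0)
checkpoint = {"model_state": {
    "conv.body.0.weight": torch.randn(8, 1, 3, 3),
    "conv.body.1.weight": torch.randn(8, 8, 1, 1),
    **{f"bn.{name}": tensor for name, tensor in bn_state.items()},
}}

result = model.load_state_dict(checkpoint["model_state"], strict=False)
print(result.missing_keys)     # ['conv.weight']
print(result.unexpected_keys)  # ['conv.body.0.weight', 'conv.body.1.weight']

# Matching keys were loaded; the skipped conv still has its initial weights.
print(torch.equal(model.bn.weight, torch.full((8,), 2.0)))  # True
print(torch.equal(model.conv.weight, conv_before))          # True
```

Running the same check on the DeepLabV3 checkpoint should list the four weights from the error message under missing_keys; any layer listed there runs with its initial weights rather than the trained ones.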