How to fix "RuntimeError: Error(s) in loading state_dict for DeepLabV3: Missing key(s) in state_dict:"

When loading trained weights to run prediction with a deep learning model, I ran into the following error:

RuntimeError: Error(s) in loading state_dict for DeepLabV3:

Missing key(s) in state_dict: "classifier.aspp.convs.1.0.weight", "classifier.aspp.convs.2.0.weight", "classifier.aspp.convs.3.0.weight", "classifier.classifier.0.weight".

Unexpected key(s) in state_dict: "classifier.aspp.convs.1.0.body.0.weight", "classifier.aspp.convs.1.0.body.1.weight", "classifier.aspp.convs.2.0.body.0.weight", "classifier.aspp.convs.2.0.body.1.weight", "classifier.aspp.convs.3.0.body.0.weight", "classifier.aspp.convs.3.0.body.1.weight", "classifier.classifier.0.body.0.weight", "classifier.classifier.0.body.1.weight".
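When the key lists are long, it helps to group each missing key with the unexpected keys stored under the same module path, which makes the structural difference obvious. The sketch below is plain Python, and the function names are hypothetical. With PyTorch you would pass model.state_dict().keys() and checkpoint["model_state"].keys() as the two inputs.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), roughly the lists strict load_state_dict reports."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the file, unknown to the model
    return missing, unexpected


def group_by_module(missing, unexpected):
    """For each missing parameter, list the unexpected keys stored under the same module path."""
    groups = {}
    for key in missing:
        module_prefix = key.rsplit(".", 1)[0] + "."  # "a.b.weight" -> "a.b."
        groups[key] = [k for k in unexpected if k.startswith(module_prefix)]
    return groups


if __name__ == "__main__":
    # A subset of the key names from the error message above.
    model_keys = [
        "classifier.aspp.convs.1.0.weight",
        "classifier.classifier.0.weight",
    ]
    checkpoint_keys = [
        "classifier.aspp.convs.1.0.body.0.weight",
        "classifier.aspp.convs.1.0.body.1.weight",
        "classifier.classifier.0.body.0.weight",
        "classifier.classifier.0.body.1.weight",
    ]
    missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
    for key, matches in group_by_module(missing, unexpected).items():
        print(key, "->", matches)
```

Each missing ".weight" key maps to a "body.0" / "body.1" pair, which points at a difference in layer structure rather than a renamed prefix.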

Here is how to fix it.

Code that raises the error:

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        # strict defaults to True, so any missing or unexpected key raises RuntimeError
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint

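Cause of the error:

The model saved during training and the model built for loading do not have the same structure. The key names show where they differ: each key the model expects but cannot find, such as "classifier.aspp.convs.1.0.weight", appears in the checkpoint as a pair of keys, "classifier.aspp.convs.1.0.body.0.weight" and "classifier.aspp.convs.1.0.body.1.weight". In the checkpoint, each of these layers was therefore a "body" of two convolutions (most likely a depthwise-separable convolution: a depthwise conv followed by a 1x1 pointwise conv), whereas the model built for prediction has a single ordinary convolution in the same place.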
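The mismatch can be reproduced with two tiny models. This is a minimal sketch, assuming the checkpoint's "body.0" / "body.1" layers are a depthwise conv followed by a 1x1 pointwise conv; SeparableConv and build_head are hypothetical stand-ins, not the actual DeepLabV3 code.

```python
import torch
import torch.nn as nn


class SeparableConv(nn.Module):
    """Hypothetical depthwise-separable conv whose two layers live in `body`."""

    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.body = nn.Sequential(
            # depthwise: one k x k filter per input channel -> key "body.0.weight"
            nn.Conv2d(in_ch, in_ch, kernel_size, padding=kernel_size // 2, groups=in_ch, bias=False),
            # pointwise: 1 x 1 conv that mixes channels -> key "body.1.weight"
            nn.Conv2d(in_ch, out_ch, 1, bias=False),
        )

    def forward(self, x):
        return self.body(x)


def build_head(separable):
    """Hypothetical head whose first layer is either a separable or a plain conv."""
    conv = SeparableConv(8, 8) if separable else nn.Conv2d(8, 8, 3, padding=1, bias=False)
    return nn.Sequential(conv, nn.ReLU())


# Checkpoint saved from the separable variant (the structure used in training).
saved = build_head(separable=True).state_dict()

# Model rebuilt with a plain conv (the structure used for prediction): strict loading fails.
try:
    build_head(separable=False).load_state_dict(saved)
except RuntimeError as err:
    print(err)  # "0.weight" is missing; "0.body.0.weight" and "0.body.1.weight" are unexpected

# Rebuilding with the same structure as in training loads every key with strict=True.
matching = build_head(separable=True)
matching.load_state_dict(saved)
```

State_dict keys come from attribute names, so wrapping a conv inside a "body" Sequential changes its keys even though the layer sits at the same position in the network.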

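Solution: load the weights with strict set to False (the second argument of load_state_dict), so that missing and unexpected keys are skipped instead of raising an error: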

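    import os
    import torch
    import torch.nn as nn

    # opts, model and device come from the surrounding prediction script.
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        # strict=False: skip missing/unexpected keys instead of raising RuntimeError.
        # Layers whose keys are skipped keep the weights they were initialized with.
        model.load_state_dict(checkpoint["model_state"], strict=False)
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint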
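With strict=False the error disappears, but the mismatched layers are not loaded. load_state_dict returns the keys it skipped, which is worth checking. A minimal sketch with a hypothetical one-layer model and a hand-built checkpoint dict standing in for checkpoint["model_state"]:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical plain-conv model, as built at prediction time.
model = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1, bias=False), nn.ReLU())

# Hypothetical checkpoint from a separable-conv model: same position, different key names.
checkpoint = {"model_state": {
    "0.body.0.weight": torch.randn(8, 1, 3, 3),  # depthwise conv
    "0.body.1.weight": torch.randn(8, 8, 1, 1),  # pointwise conv
}}

before = model[0].weight.detach().clone()

# strict=False skips mismatched keys and reports them instead of raising.
result = model.load_state_dict(checkpoint["model_state"], strict=False)
print("missing:", result.missing_keys)        # ['0.weight']
print("unexpected:", result.unexpected_keys)  # ['0.body.0.weight', '0.body.1.weight']

# The skipped layer keeps its random initialization: nothing was copied into it.
print("weights unchanged:", torch.equal(before, model[0].weight))  # True
```

If missing_keys is not empty, the listed layers were not restored from the checkpoint; building the prediction model with the same structure as the training model and loading with the default strict=True avoids that.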

Reposted from blog.csdn.net/qq_42514371/article/details/128839541