Key中出现“module”的情况

问题:
unexpected key “module. model.weight”, 多了一个module

原因:
预训练模型是由多gpu进行训练,加载到单gpu会出现key值不匹配的现象

解决方法(两种):

第一种:使用切片的方式取key值

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
    new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。 
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

第二种:直接使用空白代替key值

model.load_state_dict({
    
    k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
 
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

猜你喜欢

转载自blog.csdn.net/weixin_36411839/article/details/108975280
今日推荐