问题:
unexpected key “module. model.weight”, 多了一个module
原因:
预训练模型是由多gpu进行训练,加载到单gpu会出现key值不匹配的现象
解决方法(两种):
第一种:使用切片的方式取key值
# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt') # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。
第二种:直接使用空白代替key值
model.load_state_dict({
k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。