一.分布式训练load模型时报错与网络参数匹配不一致的问题:
问题1:模型并非分布式模型,但是load保存的模型有.module参数
解决:
new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
name = k[7:] # 去除前面的 ".module"
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
以上代码首先创建一个空字典new_state_dict
,然后遍历原始state_dict中的每个键值对。去除每个键名".module"前缀,然后将处理过的state_dict加载到模型中。
这样做的目的是确保在load分布式训练中保存的模型时,将键名中的".module"前缀去除掉,以匹配模型的结构。
问题2:模型是分布式模型,但是load保存的模型没有.module参数
解决:
new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
name = 'module.' + k # 在键名前添加 "module."
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
以上代码首先创建一个空字典new_state_dict
,然后遍历原始state_dict中的每个键值对。在每个键名前面添加"module."前缀,然后将处理过的state_dict加载到模型中。
这样做的目的是确保在分布式训练中加载模型时,将键名中的"module."前缀添加回来,以匹配模型的结构。
注:主要关注一下在torch.save保存模型时是以怎样的方式保存的,