分布式训练loda模型报错

一.分布式训练load模型时报错与网络参数匹配不一致的问题:

问题1:模型并非分布式模型,但是load保存的模型有.module参数

解决:

new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    name = k[7:]  # 去除前面的 ".module"
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

以上代码首先创建一个空字典new_state_dict,然后遍历原始state_dict中的每个键值对。去除每个键名".module"前缀,然后将处理过的state_dict加载到模型中。

这样做的目的是确保在load分布式训练中保存的模型时,将键名中的".module"前缀去除掉,以匹配模型的结构。

问题2:模型是分布式模型,但是load保存的模型没有.module参数

解决:

new_state_dict = {}
state_dict = checkpoint['state_dict']
for k, v in state_dict.items():
    name = 'module.' + k  # 在键名前添加 "module."
    new_state_dict[name] = v
 model.load_state_dict(new_state_dict)

以上代码首先创建一个空字典new_state_dict,然后遍历原始state_dict中的每个键值对。在每个键名前面添加"module."前缀,然后将处理过的state_dict加载到模型中。

这样做的目的是确保在分布式训练中加载模型时,将键名中的"module."前缀添加回来,以匹配模型的结构。

注:主要关注一下在torch.save保存模型时是以怎样的方式保存的,

猜你喜欢

转载自blog.csdn.net/m0_62278731/article/details/134749627
今日推荐