Pytorch distributed 多卡并行载入模型
前面的博客介绍了pytorch多卡distribute的方法,这次来介绍下如何载入模型。
目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。
大部分情况下,我们在测试时不需要多卡并行计算。所以,我在测试时只使用单卡。
from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device) #自己的模型
state_dict = torch.load(args.model_path) #存放模型的位置
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict (new_state_dict)