pytorch 多卡并行载入部分网络模型

pytorch 多卡并行载入部分网络模型

我们在做深度学习的时候经常会使用预训练的模型。但是一旦自己修改了网络架构,就无法load pretrained model。 因为模型文件保存的参数,有一部分是不需要的,或者有一部分参数是缺失的。

为了在这种情况下,成功导入模型,我们需要如下操作

操作的前提是我们存在已保存的模型参数

model = Net()
torch.save(model.state_dict(),'xxx.path')

接下来就好办了

 device = torch.device("cuda:2" if args.cuda else "cpu")

 #Try to load models
 model = DGCNN(args).to(device)
 print(str(model))

 device_ids = [2,3]
 model = nn.DataParallel(model,device_ids=device_ids) #使用2,3号显卡进行训练
 save_model = torch.load('model.t7')
 
 model_dict =  model.state_dict()

 state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
 print(state_dict.keys())  
 
 model_dict.update(state_dict)
 model.load_state_dict(model_dict)

update之后,model_dict和state_dict中具有相同键的值已经同步了。
可以开始愉快的训练了!

参考

https://blog.csdn.net/qq_34914551/article/details/87871134

发布了131 篇原创文章 · 获赞 6 · 访问量 6919

猜你喜欢

转载自blog.csdn.net/Orientliu96/article/details/104583251