官方宣称,保存和加载模型参数有两种方式:
方式一:
torch.save(net.state_dict(),path)
功能:保存训练完的网络的各层参数(即weights和bias)
其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)
net2.load_state_dict(torch.load(path))
功能:加载保存到path中的各层参数到神经网络
注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数
方式二:
torch.save(net,path)
功能:保存训练完的整个网络模型(不止weights和bias)
net2=torch.load(path)
功能:加载保存到path中的整个神经网络
经过自己的尝试之后,发现这种方式只能保存nn.Module模块中的参数,如果想要保存global_step之类的信息,需要一些小技巧:
state = {'net':model.state_dict(),
'optimizer':optim.state_dict(),
'global_step':global_step,
'best_acc':best_acc,
'best_step':best_step}
torch.save(state, args.saved_model_path)
然后加载的时候,用如下的方式加载:
checkpoint = torch.load(args.saved_model_path)
model.load_state_dict(checkpoint['net'])
optim.load_state_dict(checkpoint['optimizer'])
best_acc = checkpoint['best_acc']
best_step = checkpoint['best_step']
global_step = checkpoint['global_step']
python的变量也可以用类似这样的方式保存和加载,不得不说pytorch真的是很方便啊。