Pytorch模型的加载和保存

版权声明:如使用此博客内容,请经过作者同意,谢谢 https://blog.csdn.net/qq_40994943/article/details/85218682

1.保存整个模型的结构和参数,保存对象为model
torch.save(model,’./model.pth’)
保存对象的参数,保存的对象是模型的状态model.state_dict()
torch.save(model.state_dict(),’./model_state.pth)

2.加载模型结构和参数
load_model=torch.load('model.pth)
加载模型参数信息,需要先导入模型结构,然后通过model.load_state_dic(torch.load(‘model_state.pth’))导入

猜你喜欢

转载自blog.csdn.net/qq_40994943/article/details/85218682