Pytorch学习(四) --- 模型的保存和加载

Pytorch提供了两种方法进行模型的保存和加载。

第一种(推荐):
该方法值保存和加载模型的参数

# 保存
torch.save(the_model.state_dict(), PATH)
# 加载
# 定义模型
the_model = TheModelClass(*args, **kwargs)
# 加载模型
the_model.load_state_dict(torch.load(PATH))

例如:

import torch
import torchvision.models as models
# 创建模型
model = models.resnet101().cuda()
'''
训练过程
'''
# 保存训练后的模型
torch.save(model.state_dict(), './resnet101_test.pt'.)

第二种:
保存和加载整个模型。

# 保存
torch.save(the_model, PATH)
# 加载
the_model = torch.load(PATH)
原创文章 96 获赞 24 访问量 3万+

猜你喜欢

转载自blog.csdn.net/c2250645962/article/details/105263338