【pytorch】Pytorch训练好的模型转torchscript

TorchScript,它是PyTorch模型(子类nn.Module)的中间表示,可以在高性能环境(例如C ++)中运行。利用torch.jit.trace()函数

import torch
import torchvision
model=get_model #获得你的模型
model.load_state_dict(torch.load('./best.pth', map_location='cpu'))#加载你训练好的模型权重
model  = model.eval()#验证
dummy_input = torch.randn(1,3,299,299)#固定的输入维度
traced_script_module = torch.jit.trace(model, dummy_input)#
traced_script_module.save('./tt.pt')#保存路径

ps:如果你训练的时候是模型和参数保存在一起了直接model=torch.load('.pth')就可以了,上面代码演示的权重单独保存的记载方式

猜你喜欢

转载自blog.csdn.net/qq_44992785/article/details/129346751