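Converting a PyTorch Model to a TorchScript Model

TorchScript is an intermediate representation of a PyTorch model (a subclass of nn.Module) that can be run in a high-performance environment such as C++.

There are two ways to convert a model: by tracing and by annotation.

1. Conversion via tracing

Tracing is the more commonly used approach, but it has a drawback: it records only the operations executed for the example input, so anything that depends on that input, such as its shape or data-dependent control flow, can end up fixed in the traced graph.

The tracing example from the official documentation: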

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

The traced ScriptModule can now be evaluated identically to a regular PyTorch module:

In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
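As a quick sanity check of that claim, the sketch below traces a small stand-in model (a hypothetical two-layer nn.Sequential, chosen so the example runs without torchvision) and compares its output with the eager module. It is an illustration under those assumptions, not part of the official example.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for ResNet18: small enough to run anywhere.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)
model.eval()

# Trace with one example input, as in the official snippet above.
example = torch.rand(1, 4)
traced = torch.jit.trace(model, example)

# Evaluate both modules on a new input and compare the results.
x = torch.ones(1, 4)
with torch.no_grad():
    same_output = torch.allclose(model(x), traced(x))

print(type(traced).__name__, same_output)
```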

Most examples found online also use tracing.
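The drawback of tracing mentioned at the start of this section can be reproduced with a small sketch. The function `branchy` and its constants are hypothetical, picked only so that the two branches give visibly different results. Tracing with a positive input records the first branch only, and PyTorch typically emits a TracerWarning about the tensor-to-bool conversion while doing so.

```python
import torch

def branchy(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow: the branch taken depends on the values in x.
    if x.sum() > 0:
        return x * 2
    return x + 100

# The example input has a positive sum, so only the `x * 2` branch is recorded.
traced = torch.jit.trace(branchy, torch.ones(3))

negative = -torch.ones(3)
eager_out = branchy(negative)   # eager Python takes the `x + 100` branch
traced_out = traced(negative)   # the trace replays the recorded `x * 2` branch

print(eager_out, traced_out)
```

The two outputs differ for the negative input, which is why models with this kind of branching are better converted with torch.jit.script, described next.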

2. Conversion via annotation

The official example:

Converting to Torch Script via Annotation

Under certain circumstances, such as if your model employs particular forms of control flow, you may want to write your model in Torch Script directly and annotate your model accordingly. For example, say you have the following vanilla PyTorch model:

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

Because the forward method of this module uses control flow that is dependent on the input, it is not suitable for tracing. Instead, we can convert it to a ScriptModule. In order to convert the module to the ScriptModule, one needs to compile the module with torch.jit.script as follows:

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
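To see that scripting keeps both branches, the following sketch rebuilds the module above and feeds it one input per branch. The input values are assumptions chosen only to trigger each branch. Note that torch.jit.script generally needs the class to be defined in a source file it can read.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)   # matrix-vector product, shape (N,)
        else:
            output = self.weight + input     # broadcast add, shape (N, M)
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)

positive = torch.ones(20)    # sum > 0, takes the mv branch
negative = -torch.ones(20)   # sum < 0, takes the addition branch

with torch.no_grad():
    out_pos = sm(positive)
    out_neg = sm(negative)
    matches_eager = (
        torch.allclose(out_pos, my_module(positive))
        and torch.allclose(out_neg, my_module(negative))
    )

print(out_pos.shape, out_neg.shape, matches_eager)
```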

If you need to exclude some methods in your nn.Module because they use Python features that TorchScript doesn’t support yet, you could annotate those with @torch.jit.ignore.
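A minimal sketch of @torch.jit.ignore follows. The module `Net` and its method `python_only` are hypothetical; the method uses a dynamic getattr lookup as a stand-in for any Python feature the compiler does not handle. The ignored method is left as ordinary Python and is called through the interpreter when the scripted module runs.

```python
import torch

class Net(torch.nn.Module):
    @torch.jit.ignore
    def python_only(self, x: torch.Tensor) -> torch.Tensor:
        # Dynamic attribute lookup by a computed name: plain Python,
        # left uncompiled because of the decorator.
        op = getattr(torch, "ta" + "nh")
        return op(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The compiled forward calls back into Python for the ignored method.
        return self.python_only(x) + 1

scripted = torch.jit.script(Net())

x = torch.zeros(3)
out = scripted(x)
print(out)
```

As far as the PyTorch documentation describes it, a module whose compiled code calls an ignored method still depends on Python and cannot be exported with save; @torch.jit.unused is the documented alternative when the module has to be saved.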

sm is an instance of ScriptModule that is ready for serialization.

3. Serializing Your Script Module to a File

Once you have a ScriptModule in your hands, either from tracing or annotating a PyTorch model, you are ready to serialize it to a file. Later on, you’ll be able to load the module from this file in C++ and execute it without any dependency on Python. Say we want to serialize the ResNet18 model shown earlier in the tracing example. To perform this serialization, simply call save on the module and pass it a filename:

traced_script_module.save("traced_resnet_model.pt")

This will produce a traced_resnet_model.pt file in your working directory. If you would also like to serialize sm, call sm.save("my_module_model.pt"). We have now officially left the realm of Python and are ready to cross over to the sphere of C++.
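Before crossing over to C++, the saved file can be checked from Python with torch.jit.load. The sketch below is a round trip under assumed names: a small nn.Linear stands in for ResNet18, and the file traced_linear.pt is written to a temporary directory rather than the working directory.

```python
import os
import tempfile

import torch

torch.manual_seed(0)

# Hypothetical stand-in model; any traced or scripted module is saved the same way.
model = torch.nn.Linear(4, 2).eval()
traced = torch.jit.trace(model, torch.rand(1, 4))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "traced_linear.pt")
    traced.save(path)                 # serialize the ScriptModule
    file_exists = os.path.exists(path)
    loaded = torch.jit.load(path)     # load it back without the model's Python class

x = torch.ones(1, 4)
with torch.no_grad():
    roundtrip_ok = torch.allclose(traced(x), loaded(x))

print(file_exists, roundtrip_ok)
```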

Reference: Converting a PyTorch Model to TorchScript Format


Reposted from blog.csdn.net/juluwangriyue/article/details/108635280