Convertir le modèle Pytorch en modèle Torchscript

TorchScript, qui est une représentation intermédiaire du modèle PyTorch (sous-classe nn.Module), peut être exécuté dans un environnement hautes performances (tel que C ++).

Il existe deux méthodes de conversion, l'une consiste à suivre la conversion et l'autre à convertir via des annotations.

1. Suivre la conversion

Elle est couramment utilisée pour suivre les conversions, mais cette méthode présente l'inconvénient que la taille d'entrée est fixe.

L'exemple de suivi donné sur le site officiel:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

Le tracé  ScriptModule peut maintenant être évalué de la même manière qu'un module PyTorch classique:

In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)

Les exemples les plus courants trouvés sur Internet sont le suivi des conversions.

 

2. Conversion de commentaires

Exemple de site Web officiel

Conversion en script Torch via annotation

Dans certaines circonstances, par exemple si votre modèle utilise des formes particulières de flux de contrôle, vous souhaiterez peut-être écrire votre modèle directement dans Torch Script et annoter votre modèle en conséquence. Par exemple, disons que vous avez le modèle de Pytorch vanille suivant:

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

Étant donné que la  forward méthode de ce module utilise un flux de contrôle qui dépend de l'entrée, il ne convient pas pour le traçage. Au lieu de cela, nous pouvons le convertir en fichier  ScriptModule. Afin de convertir le module en  ScriptModule, il faut compiler le module avec  torch.jit.script comme suit:

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

Si vous devez exclure certaines méthodes de votre  nn.Module parce qu'elles utilisent des fonctionnalités Python que TorchScript ne prend pas encore en charge, vous pouvez les annoter avec @torch.jit.ignore

my_module est une instance de  ScriptModule qui est prête pour la sérialisation.

Étape 2: sérialisation de votre module de script dans un fichier

Une fois que vous avez un  ScriptModule entre vos mains, que ce soit en traçant ou en annotant un modèle PyTorch, vous êtes prêt à le sérialiser dans un fichier. Plus tard, vous pourrez charger le module à partir de ce fichier en C ++ et l'exécuter sans aucune dépendance à Python. Supposons que nous souhaitons sérialiser le  ResNet18 modèle présenté précédemment dans l'exemple de traçage. Pour effectuer cette sérialisation, il suffit d'appeler  save  sur le module et de lui passer un nom de fichier:

traced_script_module.save("traced_resnet_model.pt")

Cela produira un  traced_resnet_model.pt fichier dans votre répertoire de travail. Si vous souhaitez également sérialiser  my_module, appelez  my_module.save("my_module_model.pt") Nous avons maintenant officiellement quitté le royaume de Python et sommes prêts à passer dans la sphère de C ++.

 

 

 

 

Référence: Convertir le modèle PyTorch au format TorchScript

Je suppose que tu aimes

Origine blog.csdn.net/juluwangriyue/article/details/108635280
conseillé
Classement