Convertir le modèle PPOCRv3 en pytorch

Préface

La version PaddleOCRv3 est sortie il y a quelque temps et les modèles de détection et de reconnaissance ont été mis à jour. Les performances ont été grandement améliorées. Conformément au principe de la prostitution quand vous le pouvez, j'ai commencé à me prostituer dès le premier jour de sa sortie. Bien que le les performances du nouveau modèle sont par rapport au précédent, une grande amélioration, mais à première vue, la structure du modèle est beaucoup plus compliquée et son déploiement est beaucoup plus difficile. À ce stade, la conversion du framework paddle vers d'autres Les frameworks de déploiement ne peuvent être obtenus qu'en convertissant paddle2onnx puis vers d'autres frameworks, je prévois donc de sortir du piège et de fournir import paddle en tant que version Torch du modèle : transférer le poids du modèle du framework paddle vers pytorch pour offrir plus de choix pour le plan de déploiement. Après être passé au framework pytorch, vous pouvez passer à d'autres méthodes de déploiement à partir de pytorch. Prenons l'exemple précédent : utilisez pnnx pour pytorch Model vers ncnn model .

Comparaison avec les performances du modèle précédent :
Insérer la description de l'image ici
L'implémentation du code de ce projet est basée sur :

1. pagaie2torche

Parlons d'abord du principe de conversion. Parce que paddlepaddle et pytorch sont tous deux des frameworks dynamiques, la conversion est relativement simple. Pour que le modèle paddle soit converti, il suffit d'utiliser torch pour reconstruire la même structure de modèle de réseau, puis de retirer le poids de pagaie un par un. Les valeurs correspondantes sont attribuées à chaque couche. Il semble que le processus soit relativement simple, mais après tout, ce sont des frameworks différents, et certaines implémentations d'OP sont également différentes, il est donc inévitable qu'il y ait de nombreux pièges.

Avant la conversion, regardons d'abord quels modules PaddleOCRV3 a mis à jour par rapport à la version précédente du modèle :
Le premier est le modèle de détection :

Module de détection :

  1. LK-PAN : structure PAN à grand champ récepteur
  2. DML : Stratégie d'apprentissage mutuel modèle pour l'enseignant
  3. RSE-FPN : structure FPN du mécanisme d'attention résiduelle

Module d'identification :

  • SVTR_LCNet : réseau léger de reconnaissance de texte
  • GTC : l'attention guide la stratégie de formation du CTC
  • TextConAug : stratégie d'augmentation des données pour l'extraction d'informations contextuelles sur les textes
  • TextRotNet : modèle pré-entraîné auto-supervisé
  • UDML : stratégie d'apprentissage mutuel fédéré
  • UIM : solution d'exploration de données non labellisée

Pour plus de détails, veuillez consulter le rapport technique officiel de PPOCRV3 . Ici, nous devons uniquement prêter attention aux modules auxquels nous devons prêter attention pendant notre processus de conversion.

2. Conversion du modèle de détection

Le premier est le module de détection. Le module de détection comporte trois parties à mettre à jour. Nous devons seulement nous concentrer sur RSE-FPN, car les deux premières sont des optimisations du modèle enseignant par apprentissage par distillation pendant le processus de formation.

RSE-FPN (Residual Squeeze-and-Excitation FPN), comme le montre la figure ci-dessous, introduit la structure résiduelle et la structure d'attention du canal, remplace la couche convolutive du FPN par la couche RSEConv de la structure d'attention du canal et améliore encore la représentation de la carte des caractéristiques. Étant donné que le nombre de canaux FPN dans le modèle de détection de PP-OCRv2 est très faible, seulement 96, si SEblock est directement utilisé pour remplacer la convolution dans FPN, les caractéristiques de certains canaux seront supprimées et la précision diminuera. L'introduction d'une structure résiduelle dans RSEConv atténuera les problèmes ci-dessus et améliorera l'effet de détection de texte. Mettez à jour davantage la structure FPN du modèle étudiant CML dans PP-OCRv2 vers RSE-FPN, et la moyenne du modèle étudiant peut être encore améliorée de 84,3 % à 85,4 % : Implémentation du code pytorch RSE-FPN
Insérer la description de l'image ici
:

class RSELayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):
        super(RSELayer, self).__init__()
        self.out_channels = out_channels
        self.in_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=kernel_size,
            padding=int(kernel_size // 2),
            bias=False)
        self.se_block = SEBlock(self.out_channels,self.out_channels)
        self.shortcut = shortcut

    def forward(self, ins):
        x = self.in_conv(ins)
        if self.shortcut:
            out = x + self.se_block(x)
        else:
            out = self.se_block(x)
        return out


class RSEFPN(nn.Module):
    def __init__(self, in_channels, out_channels=256, shortcut=True, **kwargs):
        super(RSEFPN, self).__init__()
        self.out_channels = out_channels
        self.ins_conv = nn.ModuleList()
        self.inp_conv = nn.ModuleList()

        for i in range(len(in_channels)):
            self.ins_conv.append(
                RSELayer(
                    in_channels[i],
                    out_channels,
                    kernel_size=1,
                    shortcut=shortcut))
            self.inp_conv.append(
                RSELayer(
                    out_channels,
                    out_channels // 4,
                    kernel_size=3,
                    shortcut=shortcut))

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2) + y

    def _upsample_cat(self, p2, p3, p4, p5):
        p3 = F.interpolate(p3, scale_factor=2)
        p4 = F.interpolate(p4, scale_factor=4)
        p5 = F.interpolate(p5, scale_factor=8)
        return torch.cat([p5, p4, p3, p2], dim=1)

    def forward(self, x):
        c2, c3, c4, c5 = x

        in5 = self.ins_conv[3](c5)
        in4 = self.ins_conv[2](c4)
        in3 = self.ins_conv[1](c3)
        in2 = self.ins_conv[0](c2)

        out4 = self._upsample_add(in5, in4)
        out3 = self._upsample_add(out4, in3)
        out2 = self._upsample_add(out3, in2)

        p5 = self.inp_conv[3](in5)
        p4 = self.inp_conv[2](out4)
        p3 = self.inp_conv[1](out3)
        p2 = self.inp_conv[0](out2)

        x = self._upsample_cat(p2, p3, p4, p5)
        return x

Le réseau complet est divisé en trois parties : Backbone (MobileNetV3), Neck (RSEFPN) et Head (DBHead). Avec l'aide du projet PytorchOCR , ces trois parties sont implémentées séparément, puis le réseau est construit.

from torch import nn
from det.DetMobilenetV3 import MobileNetV3
from det.DB_fpn import DB_fpn,RSEFPN,LKPAN
from det.DetDbHead import DBHead

backbone_dict = {
    
    'MobileNetV3': MobileNetV3}
neck_dict = {
    
    'DB_fpn': DB_fpn,'RSEFPN':RSEFPN,'LKPAN':LKPAN}
head_dict = {
    
    'DBHead': DBHead}

class DetModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {
      
      backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {
      
      neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {
      
      head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'DetModel_{
      
      backbone_type}_{
      
      neck_type}_{
      
      head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

if __name__=="__main__":
    db_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV3', model_name='large',scale=0.5,pretrained=True),
        neck=AttrDict(type='RSEFPN', out_channels=96),
        head=AttrDict(type='DBHead')
    )

    model = DetModel(db_config)

Utilisez ensuite le modèle d'entraînement à la détection de texte de paddleOCRV3 (notez que vous ne pouvez utiliser que le modèle d'entraînement), supprimez les poids du modèle et les valeurs clés correspondantes, et initialisez-les respectivement dans le modèle de torche. Le code complet est lié à la fin de l'article.

def load_state(path,trModule_state):
    """
    记载paddlepaddle的参数
    :param path:
    :return:
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))

    state_dict = {
    
    }
    for i, key in enumerate(state.keys()):
        if key =="StructuredToParameterName@@":
            continue
        state_dict[trModule_state[i]] = torch.from_numpy(state[key])

    return state_dict

3. Conversion du modèle d'identification

La conversion du modèle de reconnaissance est beaucoup plus compliquée que le modèle de détection.Le module de reconnaissance de PP-OCRv3 est optimisé sur la base de l'algorithme de reconnaissance de texte SVTR. SVTR n'utilise plus la structure RNN. En introduisant la structure Transformers, il peut exploiter plus efficacement les informations contextuelles des images de lignes de texte, améliorant ainsi les capacités de reconnaissance de texte. Parmi les nombreuses optimisations de reconnaissance ci-dessus, nous devons seulement nous concentrer sur la première optimisation : SVTR_LCNet et les autres sont Les techniques de formation utilisées dans le processus de formation n'ont pas besoin d'être utilisées dans le processus de conversion de modèle.

SVTR_LCNet est un réseau léger de reconnaissance de texte qui intègre le réseau SVTR basé sur Transformer et le réseau CNN léger PP-LCNet pour les tâches de reconnaissance de texte. Le réseau global est le suivant : En utilisant ce réseau, la vitesse de prédiction est meilleure que PP-OCRv2 La
Insérer la description de l'image ici
reconnaissance Le modèle de reconnaissance est de 20 %, mais comme la stratégie de distillation n’est pas utilisée, le modèle de reconnaissance est légèrement moins efficace. De plus, la hauteur de normalisation de l'image d'entrée est encore augmentée de 32 à 48, et la vitesse de prédiction est légèrement plus lente, mais l'effet de modèle est grandement amélioré et la précision de reconnaissance atteint 73,98 % (+2,08 %), ce qui est proche à l'effet de modèle de reconnaissance de PP-OCRv2 en utilisant la stratégie de distillation.Processus expérimental d'ablation :
Insérer la description de l'image ici

De même, le modèle de réseau de torche est construit sur la base de la structure du réseau de reconnaissance de la palette. Le modèle est divisé en trois parties : Backbone (LCNet), Encoder (SVTR Transformers) et Head (MultiHead). La partie Encoder utilise l'encodage de la structure Transformers de SVTR :

class EncoderWithSVTR(nn.Module):
    def __init__(
            self,
            in_channels,
            dims=64,  # XS
            depth=2,
            hidden_dims=120,
            use_guide=False,
            num_heads=8,
            qkv_bias=True,
            mlp_ratio=2.0,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path=0.,
            qk_scale=None):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(
            in_channels, in_channels // 8, padding=1)
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1)

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer="Swish",
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                epsilon=1e-05,
                prenorm=False) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(
            hidden_dims, in_channels, kernel_size=1)
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1)

        self.conv1x1 = ConvBNLayer(
            in_channels // 8, dims, kernel_size=1)
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        # for use guide
        if self.use_guide:
            z = x.clone()
            z.stop_gradient = True
        else:
            z = x
        # for short cut
        h = z
        # reduce dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block
        B, C, H, W = z.shape
        z = z.flatten(2).permute([0, 2, 1])
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)
        # last stage
        z = z.reshape([-1, H, W, C]).permute([0, 3, 1, 2])
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))
        return z

La partie Head est multi-têtes, mais seul CTCHead est réellement utilisé lors de l'inférence, et la SARHead pendant la formation est supprimée, cette partie n'a donc pas besoin d'être ajoutée lors de la construction du réseau.

class MultiHead(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_c = kwargs.get('n_class')
        self.head_list = kwargs.get('head_list')
        self.gtc_head = 'sar'
        # assert len(self.head_list) >= 2
        for idx, head_name in enumerate(self.head_list):
            # name = list(head_name)[0]
            name = head_name
            if name == 'SARHead':
                # sar head
                sar_args = self.head_list[name]
                self.sar_head = eval(name)(in_channels=in_channels, out_channels=self.out_c, **sar_args)
            if name == 'CTC':
                # ctc neck
                self.encoder_reshape = Im2Seq(in_channels)
                neck_args = self.head_list[name]['Neck']
                encoder_type = neck_args.pop('name')
                self.encoder = encoder_type
                self.ctc_encoder = SequenceEncoder(in_channels=in_channels,encoder_type=encoder_type, **neck_args)
                # ctc head
                head_args = self.head_list[name]
                self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels,n_class=self.out_c, **head_args)
            else:
                raise NotImplementedError(
                    '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder, targets)
        head_out = dict()
        head_out['ctc'] = ctc_out
        head_out['ctc_neck'] = ctc_encoder
        return ctc_out                          # infer   不经过SAR直接返回
        
        # # eval mode
        # print(not self.training)
        # if not self.training:                 # training
        #     return ctc_out
        # if self.gtc_head == 'sar':
        #     sar_out = self.sar_head(x, targets[1:])
        #     head_out['sar'] = sar_out
        #     return head_out
        # else:
        #     return head_out

Construction complète du réseau :

from torch import nn

from rec.RNN import SequenceEncoder, Im2Seq,Im2Im
from rec.RecSVTR import SVTRNet
from rec.RecMv1_enhance import MobileNetV1Enhance

from rec.RecCTCHead import CTC,MultiHead

backbone_dict = {
    
    "SVTR":SVTRNet,"MobileNetV1Enhance":MobileNetV1Enhance}
neck_dict = {
    
    'PPaddleRNN': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
head_dict = {
    
    'CTC': CTC,'Multi':MultiHead}


class RecModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'
        backbone_type = config.backbone.pop('type')
        assert backbone_type in backbone_dict, f'backbone.type must in {
      
      backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)

        neck_type = config.neck.pop('type')
        assert neck_type in neck_dict, f'neck.type must in {
      
      neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)

        head_type = config.head.pop('type')
        assert head_type in head_dict, f'head.type must in {
      
      head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **config.head)

        self.name = f'RecModel_{
      
      backbone_type}_{
      
      neck_type}_{
      
      head_type}'

    def load_3rd_state_dict(self, _3rd_name, _state):
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

if __name__=="__main__":

    rec_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5,last_conv_stride=[1,2],last_pool_type='avg'),
        neck=AttrDict(type='None'),
   head=AttrDict(type='Multi',head_list=AttrDict(CTC=AttrDict(Neck=AttrDict(name="svtr",dims=64,depth=2,hidden_dims=120,use_guide=True)),
                                                       # SARHead=AttrDict(enc_dim=512,max_text_length=70)
                                                      ),
                      n_class=6625)
    )

    model = RecModel(rec_config)

De même, chargez le modèle d'entraînement à la reconnaissance de paddleocrv3, retirez la valeur clé correspondant au poids et initialisez-la dans le modèle de torche.Cependant, ce qu'il faut noter ici est le problème de forme du poids de la couche de liaison complète dans la pagaie et le couche de liaison complète dans Torch. Lorsque la couche de liaison est affectée à la couche de liaison complète de Torch, les poids doivent être transposés (transpose() :

def load_state(path,trModule_state):
    """
    记载paddlepaddle的参数
    :param path:
    :return:
    """
    if os.path.exists(path + '.pdopt'):
        # XXX another hack to ignore the optimizer state
        tmp = tempfile.mkdtemp()
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        state = fluid.io.load_program_state(dst)
        shutil.rmtree(tmp)
    else:
        state = fluid.io.load_program_state(path)

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))
    keys = ["head.ctc_encoder.encoder.svtr_block.0.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc2.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc2.weight",
            "head.ctc_head.fc.weight",
            ]

    state_dict = {
    
    }
    for i, key in enumerate(state.keys()):
        if key =="StructuredToParameterName@@":
            continue
        if i > 238:
            j = i-239
            if j <= 195:
                if trModule_state[j] in keys:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key]).transpose(0,1)
                else:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key])

    return state_dict

Lien du modèle de formation PaddleOCR PaddleOCR :
Insérer la description de l'image ici
Le code complet a été lancé sur github, bienvenue pour en tirer des leçons.

paddle2torch_PPOCRv3

Je suppose que tu aimes

Origine blog.csdn.net/qq_39056987/article/details/124921515
conseillé
Classement