Chinese-CLIP 转换为 ONNX 和 TensorRT 并进行推理

1.onnx

1.1导出onnx

使用 pytorch_to_onnx.py 脚本

注意:需要将代码中的 opset_version 修改为 14。

执行:

DATAPATH="/data/LLM/clip-data/"

checkpoint_path=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt

python pytorch_to_onnx.py        --model-arch ViT-B-16        --pytorch-ckpt-path ${checkpoint_path}        --save-onnx-path ${DATAPATH}/deploy/vit-b-16        --convert-text --convert-vision

代码:

# -*- coding: utf-8 -*-
"""
This script converts PyTorch implemented Chinese-CLIP (text or vision) model to ONNX format for CPU/GPU deployment.
"""

import os
import argparse
from PIL import Image
import torch
import torch.onnx
from onnx import load_model, save_model
from onnxmltools.utils import convert_float_to_float16
import cn_clip.clip as clip
from cn_clip.clip.utils import _MODELS, _MODEL_INFO, _download, available_models, create_model, image_transform

def parse_args():
    """Parse the command-line options of the ONNX conversion script.

    The options are declared in a table and registered in order, so the
    generated ``--help`` output lists them in the same sequence.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes
        ``model_arch``, ``pytorch_ckpt_path``, ``download_root``,
        ``save_onnx_path``, ``convert_text``, ``convert_vision`` and
        ``context_length``.

    Raises:
        SystemExit: raised by argparse when a required option is missing
        or a value is not accepted.
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments for add_argument), in registration order.
    option_table = (
        ("--model-arch", {
            "required": True,
            "choices": ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
            "help": "Specify the architecture (model scale) of Chinese-CLIP model to be converted.",
        }),
        ("--pytorch-ckpt-path", {
            "default": None,
            "type": str,
            "help": "Path of the input PyTorch Chinese-CLIP checkpoint. Default to None which will automatically download the pretrained checkpoint.",
        }),
        ("--download-root", {
            "default": None,
            "type": str,
            "help": "If --pytorch-ckpt-path is None, official pretrained ckpt will be downloaded under --download-root directory and converted. Default to ~/cache/clip/ .",
        }),
        ("--save-onnx-path", {
            "required": True,
            "type": str,
            "help": "Path (prefix) of the output converted ONNX Chinese-CLIP text or vision model.",
        }),
        ("--convert-text", {
            "action": "store_true",
            "help": "Whether to convert the text encoder (text feature extractor) into ONNX.",
        }),
        ("--convert-vision", {
            "action": "store_true",
            "help": "Whether to convert the vision encoder (vision feature extractor) into ONNX.",
        }),
        ("--context-length", {
            "type": int,
            "default": 52,
            "help": "The padded length of input text (include [CLS] & [SEP] tokens). Default to 52.",
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def packing_small_onnx_files(onnx_path):
    """Re-save an ONNX model with its tensors packed into one external file.

    The model at ``onnx_path`` is loaded and saved back to the same path with
    ``save_as_external_data=True`` and ``all_tensors_to_one_file=True``, using
    ``<file name>.extra_file`` as the external data location. Afterwards,
    files in the model's directory that are named after the visual-tower
    state-dict keys, or that match ``Constant_*_attr__value``, are deleted.

    NOTE(review): this function reads a module-level ``checkpoint`` dict with
    a ``'state_dict'`` entry. It is not a parameter and is not defined in the
    visible part of the file; presumably the script body sets it before this
    function is called. Confirm against the caller.

    Args:
        onnx_path: path of the ONNX model file to repack in place.
    """
    # Local import: keeps the file's top-level import block unchanged.
    import glob

    onnx_dir, onnx_name = os.path.split(onnx_path)

    # packing small files into an extra file
    save_model(load_model(onnx_path),
            onnx_path,
            location="{}.extra_file".format(onnx_name),
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            size_threshold=1024,
            convert_attribute=True)

    # remove small files
    for key in checkpoint['state_dict']:
        if key.startswith('module.visual'):
            # Drop the 7-character "module." prefix to get the file name.
            leftover_name = key[7:]
        elif key.startswith('visual'):
            leftover_name = key
        else:
            continue
        small_file_path = os.path.join(onnx_dir, leftover_name)
        if os.path.exists(small_file_path):
            os.remove(small_file_path)

    # Delete leftover constant files without going through a shell, so paths
    # containing spaces or shell metacharacters are handled correctly.
    # glob.escape keeps special characters in the directory name literal.
    constant_pattern = os.path.join(glob.escape(onnx_dir), "Constant_*_attr__value")
    for constant_path in glob.glob(constant_pattern):
        if not os.path.isfile(constant_path):
            continue
        try:
            os.remove(constant_path)
        except OSError:
            # Best-effort cleanup, matching the previous `rm -f` behaviour of
            # not aborting the conversion when a file cannot be removed.
            pass


if __name__ == '__main__':
    args = parse_args()

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    # prepare the PyTorch model weights
    if os.path.isfile(args.pytorch_ckpt_path):
        input_ckpt_path = args.pytorch_ckpt_path
    elif args.model_arch in _MODELS:
        input_ckpt_path = _download(_MODELS[args.model_arch], args.download_root or os.path.expanduser("./cache/clip"))
    else:
        raise RuntimeError(f"Model {args.model_arch} not found; available models =

猜你喜欢

转载自blog.csdn.net/u012374012/article/details/143369151