1. ONNX
1.1 Exporting to ONNX
Use the pytorch_to_onnx.py script (from the Chinese-CLIP repository).
Note that you need to set opset_version=14 in the script's export code.
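For reference, the change amounts to making every torch.onnx.export(...) call in the script pass opset_version=14. A minimal runnable sketch on a toy module (the real script exports the CLIP text/vision encoders instead; everything except opset_version is illustrative):

import torch

# Toy stand-in for the CLIP encoder; only the opset_version argument matters here.
toy = torch.nn.Linear(4, 2)
torch.onnx.export(
    toy,
    (torch.randn(1, 4),),
    "toy.onnx",
    input_names=["x"],
    output_names=["y"],
    opset_version=14,  # the value the note above asks you to set
)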
Run:
DATAPATH="/data/LLM/clip-data/"
checkpoint_path=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt
python pytorch_to_onnx.py --model-arch ViT-B-16 --pytorch-ckpt-path ${checkpoint_path} --save-onnx-path ${DATAPATH}/deploy/vit-b-16 --convert-text --convert-vision
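Once the export finishes, the resulting ONNX files can be sanity-checked with onnxruntime. A minimal sketch, assuming the --save-onnx-path prefix above and a text model named vit-b-16.txt.fp32.onnx (adjust the suffix to whichever files the script actually produced):

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    "/data/LLM/clip-data/deploy/vit-b-16.txt.fp32.onnx",
    providers=["CPUExecutionProvider"],
)

# Query the input spec instead of hardcoding it.
inp = sess.get_inputs()[0]
print(inp.name, inp.shape, inp.type)

# Dummy batch of token ids padded to the --context-length (52) used at export
# time; int64 is assumed here, check inp.type if the model reports otherwise.
dummy = np.zeros((1, 52), dtype=np.int64)
outputs = sess.run(None, {inp.name: dummy})
print([o.shape for o in outputs])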
Code:
# -*- coding: utf-8 -*-
"""
This script converts PyTorch implemented Chinese-CLIP (text or vision) model to ONNX format for CPU/GPU deployment.
"""
import os
import argparse
from PIL import Image
import torch
import torch.onnx
from onnx import load_model, save_model
from onnxmltools.utils import convert_float_to_float16
import cn_clip.clip as clip
from cn_clip.clip.utils import _MODELS, _MODEL_INFO, _download, available_models, create_model, image_transform

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-arch",
        required=True,
        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
        help="Specify the architecture (model scale) of Chinese-CLIP model to be converted."
    )
    parser.add_argument(
        "--pytorch-ckpt-path",
        default=None,
        type=str,
        help="Path of the input PyTorch Chinese-CLIP checkpoint. Default to None which will automatically download the pretrained checkpoint."
    )
    parser.add_argument(
        "--download-root",
        default=None,
        type=str,
        help="If --pytorch-ckpt-path is None, official pretrained ckpt will be downloaded under --download-root directory and converted. Default to ~/.cache/clip/ ."
    )
    parser.add_argument(
        "--save-onnx-path",
        required=True,
        type=str,
        help="Path (prefix) of the output converted ONNX Chinese-CLIP text or vision model."
    )
    parser.add_argument(
        "--convert-text",
        action="store_true",
        help="Whether to convert the text encoder (text feature extractor) into ONNX."
    )
    parser.add_argument(
        "--convert-vision",
        action="store_true",
        help="Whether to convert the vision encoder (vision feature extractor) into ONNX."
    )
    parser.add_argument(
        "--context-length", type=int, default=52, help="The padded length of input text (include [CLS] & [SEP] tokens). Default to 52."
    )
    args = parser.parse_args()
    return args

def packing_small_onnx_files(onnx_path):
    # Pack the many small external-data files into one .extra_file; tensors
    # larger than size_threshold bytes are stored there instead of in the protobuf.
    save_model(load_model(onnx_path),
               onnx_path,
               location="{}.extra_file".format(os.path.split(onnx_path)[1]),
               save_as_external_data=True,
               all_tensors_to_one_file=True,
               size_threshold=1024,
               convert_attribute=True)
    # Remove the small per-tensor files left behind by the first save.
    # NOTE: `checkpoint` is a module-level variable; the full script loads it
    # (torch.load of the ckpt) in __main__ before this function is called.
    onnx_dir = os.path.split(onnx_path)[0]
    for key in checkpoint['state_dict']:
        if key.startswith('module.visual'):
            # strip the "module." prefix used by DataParallel checkpoints
            small_file_path = os.path.join(onnx_dir, key[7:])
            if os.path.exists(small_file_path):
                os.remove(small_file_path)
        if key.startswith('visual'):
            small_file_path = os.path.join(onnx_dir, key)
            if os.path.exists(small_file_path):
                os.remove(small_file_path)
    os.system("rm -f {}".format(os.path.join(onnx_dir, "Constant_*_attr__value")))

if __name__ == '__main__':
    args = parse_args()

    # Log params.
    print("Params:")
    for name in sorted(vars(args)):
        val = getattr(args, name)
        print(f"  {name}: {val}")

    # Prepare the PyTorch model weights (guard against a None ckpt path,
    # which would make os.path.isfile raise a TypeError).
    if args.pytorch_ckpt_path and os.path.isfile(args.pytorch_ckpt_path):
        input_ckpt_path = args.pytorch_ckpt_path
    elif args.model_arch in _MODELS:
        input_ckpt_path = _download(_MODELS[args.model_arch], args.download_root or os.path.expanduser("~/.cache/clip"))
    else:
        raise RuntimeError(f"Model {args.model_arch} not found; available models = {available_models()}")
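A caveat on the packed models: after packing_small_onnx_files runs, the weights live in the companion .extra_file, and onnx/onnxruntime resolve that file relative to the .onnx model's directory, so the two files must be copied or moved together. A minimal reload sketch (the file name assumes the fp16 vision model the command above would produce; adjust to your actual output):

from onnx import load_model

# load_model pulls in external data from the .extra_file sitting next to the
# .onnx file; moving the .onnx alone would make this fail.
model = load_model("/data/LLM/clip-data/deploy/vit-b-16.img.fp16.onnx")
print(len(model.graph.initializer), "initializers loaded (external data included)")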