Exporting an ncnn model from mmdetection YOLOX (1)

Notes:

1. This article is based on mmdetection 2.25.1. Why not the latest version? I haven't tried 3.0 yet, 2.28 has much the same issues, and the boss requires this version, so I'll use it to demonstrate the full workflow first.

2. This article exports to ONNX (ncnn comes later, be patient) directly with a script that mmdetection itself marks as "not recommended", namely tools/deployment/pytorch2onnx.py. Why not mmdeploy? For one, it wouldn't necessarily work either; besides, the boss doesn't allow it for now.

3. Oh, one more question: why use mmdetection's YOLOX instead of Megvii's original YOLOX directly? Don't ask; asking gets you "the boss requires it". That said, it's not that mmdetection is bad: with so many algorithms integrated, rough edges are inevitable, which is all the more reason to study and improve it together. (I haven't fully figured it out myself, so I won't submit a PR for now; if anyone wants to, just go ahead, no need to ask me.)

Now, straight to the point.

The relevant Python package versions are as follows:

Grab the yolox_s model from mmdetection, put it in the checkpoints folder, and give it a try:

python tools/deployment/pytorch2onnx.py configs/yolox/yolox_s_8x8_300e_coco.py checkpoints/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth --input-img demo/demo.jpg --output-file checkpoints/yolox_s.onnx  --show --shape 640 640 

Problem 1

The first error you'll hit: AssertionError: `norm_config` should only have one


This happens because mmdetection's YOLOX config (e.g. configs/yolox/yolox_s_8x8_300e_coco.py) simply does not configure input normalization. Exactly why is worth further study (comparing against the original YOLOX, for instance). And why does it still work without normalization? Could it be that the BN layers alleviate the problem? Not impossible.

Since there is no normalization, the export script's mandatory check is self-contradictory, so just remove it (original code on the left, modified code on the right; this file is modified in several places, and the complete code is given further below).
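Concretely, the relaxed check allows zero `Normalize` transforms instead of requiring exactly one; a minimal standalone sketch of the change (the function name here is mine, but the same edit appears in `parse_normalize_cfg` in the full script further below):

```python
def parse_normalize_cfg_relaxed(transforms):
    """Relaxed check: tolerate pipelines (like YOLOX's) with no Normalize step."""
    norm_config_li = [t for t in transforms if t['type'] == 'Normalize']
    # was: assert len(norm_config_li) == 1 -- which YOLOX's pipeline fails
    assert len(norm_config_li) <= 1, '`norm_config` should only have one'
    return norm_config_li[0] if norm_config_li else None

# A YOLOX-style pipeline with no Normalize no longer trips the assert:
print(parse_normalize_cfg_relaxed([{'type': 'Resize'}, {'type': 'Pad'}]))  # None
```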

 

Run it again and that error is gone.

Problem 2

Let me slot in Problem 2 first: with the PyTorch environment listed above you won't actually hit this error, but with a newer PyTorch version you may, for example:

After fixing Problem 1, Problem 2's error shows up; I tried mmdetection 2.28 as well, and it behaves the same.

The error: RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: numpy.ndarray

This error says that a numpy array was passed somewhere along the export path, which is not allowed. It is in fact the structure below, which holds some image information, including the original image.

But in our final modified code it serves no purpose. Whether other algorithms need it I'll leave aside for now; just replace it with an empty dict and we're done.

A side note: instead of modifying the code here, you could simply add --skip-postprocess to the command, i.e.:

python tools/deployment/pytorch2onnx.py configs/yolox/yolox_s_8x8_300e_coco.py checkpoints/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth --input-img demo/demo.jpg --output-file checkpoints/yolox_s.onnx  --show --shape 640 640 --skip-postprocess 

and the export goes through directly, without even hitting Problem 3. But that strips out the post-processing entirely, and we don't quite want to drop all of it; moreover, by the script's logic this flag also skips the subsequent simplify step and everything after it, so I won't take that route for now. (Later, I'll revisit a more complete approach in combination with how mmdeploy does it.)

Problem 3:

Problem 3 is the main problem to solve today.

The next error is: AttributeError: 'YOLOXHead' object has no attribute 'bbox_coder'

Sorting out the inheritance

The model class we ultimately use is the YOLOX class in mmdet/models/detectors/yolox.py. Let's first sort out its inheritance, so we can pin down exactly what's wrong and how to fix it. The figure below shows the inheritance structure of the YOLOX class: solid arrows denote parent/child relations, dashed arrows denote class/member relations.

Here is another figure. On the right are the method names of each class (one class per row), with identical method names aligned by column, so you can see at a glance which class is missing which methods. Then trace, starting from torch.onnx.export at the bottom, which methods get called in what order. As you can see, YOLOXHead's forward is called first, then BaseDenseHead's onnx_export; that method tries to call self.bbox_coder.decode(), but YOLOXHead has no such attribute, hence the error.

Next, let's briefly walk through what YOLOXHead's forward does, and what the related functions are for.

forward:

forward calls forward_single, which returns the decoupled classification scores, boxes, and objectness (in yolov5 these all live inside the 85 channels (assuming the 80 COCO classes), whereas YOLOX decouples them). The multi_apply used in forward simply calls forward_single once per output level, so the results come back in 3 lists, one per level; step through it in a debugger and it becomes clear.
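For reference, multi_apply is just a thin wrapper around `map` that transposes the per-level results; this sketch mirrors mmdet's implementation (`mmdet.core.utils.multi_apply`):

```python
from functools import partial

def multi_apply(func, *args, **kwargs):
    """Apply func to each output level and regroup the results: if func
    returns (cls, bbox, obj) per level, you get back three lists,
    one per return value, each with one entry per level."""
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))

# e.g. three feature levels, with a fake forward_single returning two values per level
cls_list, bbox_list = multi_apply(lambda lvl: (f'cls{lvl}', f'bbox{lvl}'), [0, 1, 2])
print(cls_list)   # ['cls0', 'cls1', 'cls2']
print(bbox_list)  # ['bbox0', 'bbox1', 'bbox2']
```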

get_bboxes: see the (Chinese) comments inside for the analysis

 

 

At this point it's clear: calling get_bboxes yields the final output, i.e. after confidence filtering and after NMS.

If all you want is to export an ONNX model, then calling get_bboxes right after YOLOXHead's forward does the job. But our end goal is an ncnn model, and many of the operations in get_bboxes involve operators that ncnn does not support, such as generating the grids for the 3 output levels, or NMS. So we shouldn't call get_bboxes here. All we need are the 3 output levels from forward, at most with a simple concatenation; that is exactly the output we want from the ncnn model, and the heavy post-processing is then implemented in C++ code.

C++ code

Which raises the question: do we have to write the C++ code ourselves? If nothing ready-made existed, then yes. The good news is that this time something ready-made really does exist! Megvii's original YOLOX supports exporting to ncnn out of the box, and they provide complete demos, including an Android demo and a C++ post-processing demo, usable as-is. Perfect.

github地址:https://github.com/Megvii-BaseDetection/YOLOX

I won't go through the C++ code in detail; let's just look at how the boxes are computed. As shown in the figure below, feat_ptr holds the box values output by ncnn, and the C++ code applies them onto the grid; the exponential for width and height is likewise computed on the C++ side.
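The per-anchor box decode that the C++ demo performs can be sketched in Python as follows (the function and variable names here are mine, not the demo's; the arithmetic matches the figure: center offset plus grid coordinate, exp for width/height, all scaled by the level's stride):

```python
import math

def decode_box(tx, ty, tw, th, grid_x, grid_y, stride):
    """Decode one raw YOLOX prediction on grid cell (grid_x, grid_y)."""
    cx = (tx + grid_x) * stride   # box center x, in input-image pixels
    cy = (ty + grid_y) * stride   # box center y
    w = math.exp(tw) * stride     # width via exp, done on the C++ side
    h = math.exp(th) * stride     # height via exp
    # return top-left corner plus size, the form the demo uses before NMS
    return cx - w / 2, cy - h / 2, w, h

print(decode_box(0.5, 0.5, 0.0, 0.0, 10, 10, 8))  # (80.0, 80.0, 8.0, 8.0)
```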

If you're still unsure, look at the Python code of Megvii's YOLOX and see what the output of its exported ONNX looks like.

Inspect the output of their ONNX model in Netron: the levels are merged at the end. Note the order of the 3 output levels: first 80, then 40, then 20. Also note the order of the 85 channels: box first, then objectness, then class scores. This order matters, because their C++ code necessarily parses the output in exactly this order.
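A quick sanity check of the expected output shape under this layout: for a 640×640 input, the three levels (strides 8, 16, 32) contribute 80×80, 40×40 and 20×20 anchors, and each anchor carries 4 box values, 1 objectness score, and 80 class scores:

```python
strides = [8, 16, 32]
input_size = 640
# one anchor per grid cell at each level: 6400 + 1600 + 400
num_anchors = sum((input_size // s) ** 2 for s in strides)
num_channels = 4 + 1 + 80  # bbox, objectness, class scores, in this order
print((1, num_anchors, num_channels))  # (1, 8400, 85)
```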

Modifying the Python code

Now we can modify mmdetection's code:

Implement an onnx_export method in YOLOXHead (mmdet/models/dense_heads/yolox_head.py):

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def onnx_export(self,
                    cls_scores,
                    bbox_preds,
                    objectnesses=None,
                    img_metas=None,
                    with_nms=True):
        assert len(cls_scores) == len(bbox_preds) == len(objectnesses)

        num_imgs = len(img_metas)

        # flatten cls_scores, bbox_preds and objectness
        flatten_cls_scores = [
            cls_score.reshape(num_imgs, self.cls_out_channels, -1)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.reshape(num_imgs, 4, -1)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.reshape(num_imgs, 1, -1)
            for objectness in objectnesses
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=2).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=2)
        flatten_objectness = torch.cat(flatten_objectness, dim=2).sigmoid()

        return torch.cat((flatten_bbox_preds, flatten_objectness, flatten_cls_scores), dim=1).permute(0, 2, 1)

The return value is the outputs of the 3 levels merged in the order described above.

Then implement an onnx_export method in YOLOX (mmdet/models/detectors/yolox.py):

    def onnx_export(self, img, img_metas, with_nms=True):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): input images.
            img_metas (list[dict]): List of image information.

        Returns:
            Tensor: merged predictions of shape [N, num_anchors, 85],
                channels ordered as (bbox, objectness, cls_scores).
        """
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        # get origin input shape to support onnx dynamic shape

        # get shape as tensor
        img_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = img_shape
        # get pad input shape to support onnx dynamic shape for exporting
        # `CornerNet` and `CentripetalNet`, which 'pad_shape' is used
        # for inference
        img_metas[0]['pad_shape_for_onnx'] = img_shape

        return self.bbox_head.onnx_export(
            *outs, img_metas, with_nms=with_nms)

The updated call-order diagram:

Oh, and here is the tools/deployment/pytorch2onnx.py with the modifications from earlier:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import warnings
from functools import partial

import numpy as np
import onnx
import torch
from mmcv import Config, DictAction

from mmdet.core.export import build_model_from_cfg, preprocess_example_input
from mmdet.core.export.model_wrappers import ONNXRuntimeDetector


def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=None,
                 skip_postprocess=False):

    input_config = {
        'input_shape': input_shape,
        'input_path': input_img
    }

    if normalize_cfg:
        input_config['normalize_cfg'] = normalize_cfg

    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[{}]]

    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)

        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)

    output_names = ['output']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(
        model,
        img_list,
        output_file,
        input_names=[input_name],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)

    model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, \
            you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)

        if test_img is None:
            input_config['input_path'] = input_img

        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]
        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'

        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)

        # compare a part of result
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')


def parse_normalize_cfg(test_pipeline):
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) <= 1, '`norm_config` should only have one'

    # yolox has no Normalize, see configs/yolox/yolox_s_8x8_300e_coco.py,
    # https://github.com/open-mmlab/mmdetection/pull/6443
    norm_config = norm_config_li[0] if norm_config_li else None
    return norm_config


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Show onnx graph and detection outputs')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--test-img', type=str, default=None, help='Images for test')
    parser.add_argument(
        '--dataset',
        type=str,
        default='coco',
        help='Dataset name. This argument is deprecated and will be removed \
        in future releases.')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Whether to simplify onnx model.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1216],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='mean value used for preprocess input data.This argument \
        is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='variance value used for preprocess input data. '
        'This argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export onnx with dynamic axis.')
    parser.add_argument(
        '--skip-postprocess',
        action='store_true',
        help='Whether to export model without post process. Experimental '
        'option. We do not guarantee the correctness of the exported '
        'model.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    warnings.warn('Arguments like `--mean`, `--std`, `--dataset` would be \
        parsed directly from config file and are deprecated and \
        will be removed in future releases.')

    assert args.opset_version == 11, 'MMDet only support opset 11 now'

    try:
        from mmcv.onnx.symbolic import register_extra_symbolics
    except ModuleNotFoundError:
        raise NotImplementedError('please update mmcv to version>=v1.0.4')
    register_extra_symbolics(args.opset_version)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # build the model and load checkpoint
    model = build_model_from_cfg(args.config, args.checkpoint,
                                 args.cfg_options)

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../../demo/demo.jpg')

    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.input_img,
        input_shape,
        normalize_cfg,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        test_img=args.test_img,
        do_simplify=args.simplify,
        dynamic_export=args.dynamic_export,
        skip_postprocess=args.skip_postprocess)

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

Re-exporting the ONNX

No errors this time (ignore the warning; it merely discourages using this script).

Look at the structure of our output: identical to Megvii's! Once converted to ncnn, it can be driven directly by their C++ program.

The ten thousand words omitted here will be covered in the next installment...

What's omitted is how to convert the ONNX on to ncnn, which includes replacing the Focus layer with a custom layer completed by C++ code; all of this is actually covered in that README.
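As a taste of that step: the Focus layer is just a space-to-depth slice, turning (C, H, W) into (4C, H/2, W/2) by interleaved sub-sampling, which is why it can be swapped for a custom layer. A pure-Python sketch on a single channel (the sub-image order follows YOLOX's Focus, which concatenates top-left, bottom-left, top-right, bottom-right slices along the channel dimension):

```python
def focus_single_channel(ch):
    """Split one channel (a list of rows) into the four interleaved
    half-resolution sub-images that Focus stacks along the channel dim."""
    top_left  = [row[0::2] for row in ch[0::2]]
    bot_left  = [row[0::2] for row in ch[1::2]]
    top_right = [row[1::2] for row in ch[0::2]]
    bot_right = [row[1::2] for row in ch[1::2]]
    return top_left, bot_left, top_right, bot_right

ch = [[0, 1, 2, 3],
      [4, 5, 6, 7],
      [8, 9, 10, 11],
      [12, 13, 14, 15]]
print(focus_single_channel(ch)[0])  # [[0, 2], [8, 10]]
```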

A preview first:


Reprinted from blog.csdn.net/ogebgvictor/article/details/130183744