MNN框架下的特征图格式问题

这两个星期在将一个手势关键点检测的Pytorch模型转化成MNN模型,转化完了之后进行测试,用的是MNN的Python接口。最开始的测试代码如下

import MNN
interpreter = MNN.Interpreter("test.mnn")
session = interpreter.createSession()
input_tensor = interpreter.getSessionInput(session)
tmp_input = MNN.Tensor((1,3,256,256), MNN.Halide_Type_Float, img, MNN.Tensor_DimensionType_Caffe)
input_tensor.copyFrom(tmp_input)
interpreter.runSession(session)
output_tensor = interpreter.getSessionOutput(session)
score_map = torch.Tensor(output_tensor.getData())

但是可视化的结果非常糟糕,而且将结果打印出来跟Pytorch模型的输出相比较差异也很大。为了验证自己转化模型的过程没有出错,我用resnet的一个分类模型转成mnn再用以上的代码测试,发现结果是正常的,跟pytorch模型输出基本一致。后来在调试的时候发现了一个问题,输入到mnn的图片跟输入到pytorch的图片不一致,然后发现是格式问题,pytorch用的是NCHW格式,MNN则会通过copyFrom把NCHW转化成NC4HW4,可想而知特征图也都是NC4HW4的了,所以要提取特征图的话需要先把格式转化回来。在网上找了几天的资料,终于把这个问题解决了,用的是MNN的一个包MNN.expr:
 

import MNN.expr as F
vars = F.load_as_dict("test.mnn")
inputVar = vars["input"]
if (inputVar.data_format == F.NC4HW4):
    inputVar.reorder(F.NCHW)
inputVar.write(test_img.tolist())
outputVar = vars['output']
if (outputVar.data_format == F.NC4HW4):
    outputVar = F.convert(outputVar, F.NCHW)

这样转化出来的特征图就能够正常使用了。值得注意的就是MNN.Tensor传入的图片是(3,256,256)的,但inputVar传入的是(1,3,256,256),需要用numpy增加多一个维度。

猜你喜欢

转载自blog.csdn.net/u013289254/article/details/120414559