当我们拿到别人提供的pb模型时,我们很可能不清楚inputs所需的参数,可以使用以下方法来查看模型信息:
- 找到你的tensorflow安装位置,依次定位到以下目录:
tensorflow_core/python/tools - 执行如下命令:
python saved_model_cli.py show --dir 模型路径 --all
之后可以得到如下信息:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image_shape'] tensor_info:
dtype: DT_FLOAT
shape: (2)
name: Placeholder_366:0
inputs['input_image'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, -1, 3)
name: input_1:0
The given SavedModel SignatureDef contains the following output(s):
outputs['concat_11/concat:0'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 4)
name: concat_11/concat:0
outputs['concat_12/concat:0'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: concat_12/concat:0
outputs['concat_13/concat:0'] tensor_info:
dtype: DT_INT32
shape: (-1)
name: concat_13/concat:0
Method name is: tensorflow/serving/predict
由上面可知:
- 该pd有一个签名信息,且签名名称为serving_default
- 有两个inputs,因此需要传递两个参数,其中image_shape数据为一个一维数组,里面有两个元素;input_image为一个四维数组
- 有三个输出,concat_11/concat:0为一个1*4的矩阵,concat_12/concat:0为一个float型数据,concat_13/concat:0为一个int型数据。