使用高版本的AI引擎训练,导出模型后转换成Ascend310芯片的OM格式时,有可能遇到算子不支持的情况出现,现在教大家如何合理规避这些算子。
以在TensorFlow-2.x上训练得到的模型为例,如何转换成低版本Ascend310芯片(如C32版本)可用的OM模型。更多的技巧通过这篇文章可以举一反三,灵活变通。
写在前面
由于Frozen Graph已经被TF-2.x抛弃,TF-2.x开始使用keras模型,导出是saved_model格式或者h5格式。想要转换OM模型,首先要得到TensorFlow-1.x上的Frozen Graph模型
在TF-2.x下导出Frozen Graph
假设你有一个TF-2.x下的keras model
model = tf.keras.Model(input_nodes, output_nodes)
通过以下这段代码转换成Frozen Graph
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_graph_def = frozen_func.graph.as_graph_def()
# remove final nodes generated by keras which having so many indepencies inputs.
# this will help model to be opened by netron and to be converted to OM
frozen_sub_graph_def = tf.compat.v1.graph_util.extract_sub_graph(
frozen_graph_def, dest_nodes=[out_node.name[:-2] for out_node in output_nodes])
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_sub_graph_def,
logdir="/tmp/frozen_graph/",
name="model.pb",
as_text=False)
keras导出的模型会在最后加上Indentity节点,并且已整个模型为依赖,会导致模型不能转OM,并且netron也加载不了。
中间加入以下方法
tf.compat.v1.graph_util.extract_sub_graph(...)
可以把最后的Identity节点有效移除
一、算子版本过高
转换OM的时候,可能会遇到这种算子不存在的错误,如`FusedBatchNormV3`,这是因为低版本可能只支持到`FusedBatchNorm`为止,没有V3这个版本。
这个时候其实只要通过编辑Frozen Graph文件,简单的替换PB模型文件中的算子名称,把`FusedBatchNormV3`替换成`FusedBatchNorm`就可以了,计算是一样的,不会影响精度,只会影响性能。同类型的还有`AddV2`替换成`Add`,或者其他的算子,如果能找到早期版本的对应算子,就能合理规避。
二、算子新增参数不支持
比如在Conv2D这个算子上,高版本的TF引擎会有explicit_paddings这个选项,并且会写到Graph里,这时候转换OM就会报错,提示explicit_paddings这个attribute找不到,这个时候,也是编辑Frozen Graph文件,将这个attr从Conv2D这个op中去掉。一般来说这种低版本没有,高版本新增的参数特性,为了向前兼容,都默认关闭的,所以去掉一个attr不会影响精度。
奉上以上两种方法,通过编辑Frozen Graph来规避的实现代码
import os
import tempfile
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
TMP_PBTXT = 'tmp_model.pbtxt'
TMP_COMPAT_PBTXT = 'tmp_compat_model.pbtxt'
def merge_line_by_line(fo, input_graph_def):
entry = 0
item_lines = []
for line in fo:
if not line.strip():
continue
item_lines.append(line)
if line.strip().endswith('{'):
entry += 1
elif line.strip().endswith('}'):
entry -= 1
if entry == 0:
text_format.MergeLines(item_lines, input_graph_def)
del item_lines[:]
def parse_input_graph_proto(input_graph, input_binary):
if not os.path.exists(input_graph):
raise ValueError('invalid input path')
input_graph_def = graph_pb2.GraphDef()
if input_binary:
with open(input_graph, 'rb') as f:
input_graph_content = f.read()
input_graph_def.ParseFromString(input_graph_content)
else:
with open(input_graph, 'r') as f:
merge_line_by_line(f, input_graph_def)
return input_graph_def
def compat_pb(input_pb_path, replace=True):
tmp_dir = tempfile.mkdtemp()
tmp_pbtxt_file = os.path.join(tmp_dir, TMP_PBTXT)
graph_def = parse_input_graph_proto(input_pb_path, input_binary=True)
tf.train.write_graph(graph_or_graph_def=graph_def, logdir=tmp_dir, name=TMP_PBTXT, as_text=True)
del graph_def
new_graph_def_str = ''
lines_to_cache = []
num_lines_to_skip = 0
with open(tmp_pbtxt_file, 'r') as f:
for line in f:
if num_lines_to_skip > 0:
num_lines_to_skip -= 1
continue
if 'attr {' in line.strip():
lines_to_cache.append(line)
continue
if line.strip().startswith('key: "explicit_paddings"'):
del lines_to_cache[:]
num_lines_to_skip = 5
continue
elif line.strip().startswith('key: "U"'):
del lines_to_cache[:]
num_lines_to_skip = 4
continue
elif line.strip().startswith('key: "half_pixel_centers"'):
del lines_to_cache[:]
num_lines_to_skip = 4
continue
if lines_to_cache:
new_graph_def_str += ''.join(lines_to_cache)
del lines_to_cache[:]
new_graph_def_str += line
new_graph_def_str = new_graph_def_str.replace('FusedBatchNormV3', 'FusedBatchNorm').replace('AddV2', 'Add')
tmp_compat_pbtxt_file = os.path.join(tmp_dir, TMP_COMPAT_PBTXT)
with open(tmp_compat_pbtxt_file, 'w') as f:
f.write(new_graph_def_str)
del new_graph_def_str
graph_def_compat = parse_input_graph_proto(tmp_compat_pbtxt_file, input_binary=False)
input_pb_dir, input_pb_name = os.path.split(input_pb_path)
output_pb_dir = input_pb_dir
if replace:
output_pb_name = input_pb_name
else:
output_pb_name = 'compat_' + input_pb_name
tf.train.write_graph(graph_or_graph_def=graph_def_compat, logdir=output_pb_dir, name=output_pb_name, as_text=False)
if __name__ == '__main__':
compat_pb('/tmp/model.pb', replace=False)
虽然这里有一些硬编码,不过作为一个线下用用的工具,能满足功能就行了~
这个代片段主要是为了删除half_pixel_centers这个attribute对应的一组proto描述
elif line.strip().startswith('key: "half_pixel_centers"'):
del lines_to_cache[:]
num_lines_to_skip = 4
如果读懂这个脚本,就能应对各种算子低版本兼容和属性删除了。
当前这个脚本已经可以应对很多TF-1.15向下兼容到Ascend310-C32的情况了。
三、算子本身不支持
比如leaky_relu在Ascend310-C32版本上是找不到算子实现的,那么这个时候只能通过用其他算子拼凑的方式替换了。这时候不能通过编译Frozen Graph文件来解决(过于复杂),推荐直接从源码修改。比如将
y =tf.nn.leaky_relu(x, alpha=alpha)
替换成
tf.maximum(alpha * x, x)
又例如mish激活函数找不到,那么可以用以下算子替换
y = x * tf.tanh(tf.math.log(1 + tf.exp(x)))
四、通过前/后处理规避
如果你的算子使用上述方式还不能支持,并且这个算子出现在模型的头上或者尾部,那恭喜你你还有希望。
你可以在导出模型的时候将算子涉及的这段计算从模型中拿出来,放到推理脚本的前处理后后处理。以伪代码举例
假设你的模型是:
def model(x):
y = op1(x)
y = op2(y)
y = op3(y)
假设op1和op3都不支持,并且都是一些不含有网络权重的计算,那么你导出模型的时候只导出op2部分,将op1用numpy的API写在预处理中,将op3用numpy的API写在后处理中
例如:
在声音分类中,模型的最前面要对数据进行傅里叶变换,但是傅里叶变换算子在Ascend310-C32上不支持,那么在导出模型的时候将傅里叶变换从模型的最开始摘除,然后用numpy实现,写在推理脚本的前处理
在物体检测中,模型的最后要对结果做NMS,其中涉及动态shape,在Ascend310-C32上不支持,那么导出模型直接摘除后处理,模型直接输出feature_map,然后在推理脚本的后处理做NMS(numpy的NMS也很快,不用担心性能)
from tensorflow.python.compat import compat
with compat.forward_compatibility_horizon(2019, 05, 01):
y = model(x)