Tensorflow2.X版本keras模型输出保存为frozen graph使用OpenCV调用

环境:

  1. windows10 64bit
  2. python: 3.7
  3. opencv 4.2.0
  4. tensorflow: 2.1

**目的:**利用opencv中的dnn模块对tensorflow模型进行加载。

opencv的dnn模块有函数dnn.readNetFromTensorflow,根据函数文档可知是调用pb格式的tensorflow模型,这里就入坑了,tensorflow保存的文件格式多种多种:TFLite, frozen graph, SavedModel, serving model, TFHub representation, Keras's .h5 ,tensorflow2.X版本之后推荐使用keras,原版keras默认保存的模型文件是.h5格式的,而tf.keras 模型的save方法默认保存格式是tensorflow的SavedModel(可以通过参数save_format控制),这种方法报道模型也有一个pb文件:
在这里插入图片描述
但如果使用dnn模块的相关接口去调用是无法正确读入模型的,经过查阅之后发现需要保存为frozen graph格式,目前网上搜到的大部分关于Keras模型转Frozen grpah的教程所依赖的tensorflow版本都较老,笔者使用的tf版本为最新的2.1版本,最终找到了一个靠谱的方法:

https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

对应的代码在https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py

下面笔者提供一份简单代码来演示opencv如何加载tensorflow(keras)模型:

from cv2 import dnn
import cv2
import numpy as np 
import matplotlib.pyplot as plt
import os
from keras import backend as K
from keras.models import load_model
#from tensorflow_serving.session_bundle import exporter
from keras.models import model_from_config
from keras.models import Sequential,Model
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os

print(tf.__version__)
print(cv2.__version__)
#%% opencv处理适合模型输入的图片
img_file = r"C:\Users\zhou-\Pictures\cat.jpg"
img_cv2 = cv2.imread(img_file)
print("[INFO]Image shape: ", img_cv2.shape)

# 主要图片尺寸要和模型输入匹配(mobilenet要求输入的尺寸为224*224)
inWidth = 224
inHeight = 224
blob = cv2.dnn.blobFromImage(img_cv2,
                                scalefactor=1.0 / 255,
                                size=(inWidth, inHeight),
                                mean=(0, 0, 0),
                                swapRB=False,
                                crop=False)
# blob = np.transpose(blob, (0,2,3,1)) # 适合keras mobilenet网络输入格式
print("[INFO]img shape: ", blob.shape)

#%% 保存keras模型为SaveModel会报错,相关issue见:
# https://github.com/opencv/opencv/issues/16582
model = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')
# model.save('my_model', save_format='tf') # Save model to SavedModel format

# 参考https://github.com/leimao/Frozen_Graph_TensorFlow/blob/master/TensorFlow_v2/train.py

# Save model to SavedModel format
# tf.saved_model.save(model, r"./models")

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

net = dnn.readNetFromTensorflow('frozen_models/frozen_graph.pb')

# Run a model
net.setInput(blob)
out = net.forward()

# Get a class with a highest score.
out = out.flatten()
classId = np.argmax(out)
confidence = out[classId]

# Put efficiency information.
t, _ = net.getPerfProfile()
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
cv2.putText(img_cv2, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

# Print predicted class.
def load_imagenet_classes(file_path):
    '''
    imagenet对应的标签数据如下所示:
    0: 'tench, Tinca tinca',
    1: 'goldfish, Carassius auratus',
    ...
    '''
    classes = []
    contents = None
    with open(file_path,'r') as f:
        contents = f.readlines()
    for cnt in contents:
        cnt = cnt.strip()
        classes.append(cnt.split(':')[1].strip().replace(',',''))
    
    return classes
        
classes = load_imagenet_classes('imagenet_classes.txt')

label = '%s: %.4f' % (classes[classId] if classes else 'Class #%d' % classId, confidence)
cv2.putText(img_cv2, label, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

cv2.imwrite('output-{}.png'.format(img_file.split('\\')[-1][:-4]), img_cv2)

猜你喜欢

转载自blog.csdn.net/dou3516/article/details/110872600