使用TensorFlow一步步进行目标检测(5)

本教程进行到这一步,您选择了预训练的目标检测模型,转换现有数据集或创建自己的数据集并将其转换为TFRecord文件,修改模型配置文件,并开始训练模型。接下来,您需要保存模型并将其部署到项目中。

将检查点模型(.ckpt)保存为.pb文件

回到TensorFlow目标检测文件夹,并将export_inference_graph.py文件复制到包含模型配置文件的文件夹中。

python export_inference_graph.py --input_type image_tensor --pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model

这将创建一个新目录fine_tuned_model,里面名为frozen_inference_graph.pb的模型就是您训练出来的模型。

在项目中使用模型

我在本教程中一直在研究的项目是创建一个红绿灯分类器。在Python中,我将此分类器实现为一个类。 在类的初始化部分,我创建了一个TensorFlow会话,这样就不需要在每次需要分类时创建它。

class TrafficLightClassifier(object):
    def __init__(self):
        PATH_TO_MODEL = 'frozen_inference_graph.pb'
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            # Works up to here.
            with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
        self.sess = tf.Session(graph=self.detection_graph)

在该类中,我创建了一个函数,该函数对图像进行分类,并返回图像中分类的边界框、分数和类别。

def get_classification(self, img):
    # Bounding Box Detection.
    with self.detection_graph.as_default():
        # Expand dimension since the model expects image to have shape [1, None, None, 3].
        img_expanded = np.expand_dims(img, axis=0)  
        (boxes, scores, classes, num) = self.sess.run(
            [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
            feed_dict={self.image_tensor: img_expanded})
    return boxes, scores, classes, num

此时,您需要过滤低于指定分数阈值的结果。结果会自动从最高分数到最低分数排序,因此这很容易实现。通过上面的函数返回分类结果,就是这样,您做到了!

您可以在下图中看到我实现的红绿灯分类器。

image

我最初创建本教程是因为我很难找到有关如何使用Object Detection API的资讯。我希望通过阅读本教程,您可以启动项目,让项目快速实现,这样您可以将更多时间集中在您真正感兴趣的内容上!

相关文章

  1. 使用TensorFlow一步步进行目标检测(1)
  2. 使用TensorFlow一步步进行目标检测(2)
  3. 使用TensorFlow一步步进行目标检测(3)
  4. 使用TensorFlow一步步进行目标检测(4)

image

猜你喜欢

转载自blog.csdn.net/mogoweb/article/details/81324164