使用pycaffe加载caffemodel并可视化SSD检测结果

#coding=utf-8
import numpy as np
import sys, os
import cv2

caffe_root = '/home/sam/sam/caffe_dssd_sam/'
sys.path.insert(0, caffe_root + 'python')
import caffe
import time


# Deploy-time network definition and the trained weights to load into it.
# (An earlier experiment pointed net_file at 'MobileNetSSD_test.prototxt'.)
net_file = '/home/sam/sam/caffe_dssd_sam/models/ResNet-101/car_game/SSD_car_game_321x321/deploy.prototxt'
caffe_model = '/home/sam/sam/caffe_dssd_sam/models/ResNet-101/car_game/SSD_car_game_321x321/ResNet-101_car_game_SSD_car_game_321x321_iter_10000.caffemodel'

# Directory of .jpg test images to run detection on.
test_img_dir = "/home/sam/car_game/test_a"

# The network is built once, at import time, in the TEST phase.
# To pin inference to the CPU, call caffe.set_mode_cpu() before this line.
net = caffe.Net(net_file, caffe_model, caffe.TEST)

# Display names indexed by the integer class label the detector outputs;
# index 0 is 'background'. A six-class variant of this script used:
# ('background', 'bicycle', 'motorbike', 'car', 'bus', 'truck')
CLASSES = (
    'background',
    'car',
)
def postprocess(img, out):
    """Turn the network's ``detection_out`` blob into pixel boxes, scores and labels.

    Args:
        img: the original image; only ``img.shape[0]`` (height) and
            ``img.shape[1]`` (width) are read.
        out: mapping returned by ``net.forward()``; ``out['detection_out']``
            is indexed as ``[0, 0, row, column]``.

    Returns:
        Tuple ``(box, conf, cls)``:
        ``box`` -- columns 3..6 of every row scaled by ``(w, h, w, h)`` and
        truncated to int32, i.e. (xmin, ymin, xmax, ymax) in pixels;
        ``conf`` -- column 2 of every row (detection score);
        ``cls`` -- column 1 of every row (class label, as stored in the blob).
    """
    rows = out['detection_out'][0, 0]
    height, width = img.shape[0], img.shape[1]
    # Columns 3:7 are presumably normalized to [0, 1]; scale them to pixels.
    pixel_scale = np.array([width, height, width, height])
    boxes = (rows[:, 3:7] * pixel_scale).astype(np.int32)
    scores = rows[:, 2]
    labels = rows[:, 1]
    return boxes, scores, labels


def _preprocess(origimg):
    """Return origimg resized to 321x321, mean-subtracted, as a CHW float32 array."""
    img = cv2.resize(origimg, (321, 321))
    # Per-channel mean subtraction in the image's own channel order
    # (cv2.imread yields BGR); presumably the training-time mean -- confirm.
    img = img - [104, 117, 123]
    img = img.astype(np.float32)
    # HWC -> CHW, the layout copied into the 'data' blob.
    return img.transpose((2, 0, 1))


def _draw_detections(origimg, box, conf, cls, conf_thresh):
    """Draw, label and print each detection on origimg whose score is >= conf_thresh."""
    for coords, score, label in zip(box, conf, cls):
        if score < conf_thresh:
            continue
        # Plain Python ints are the safest point type to hand to OpenCV.
        p1 = (int(coords[0]), int(coords[1]))
        p2 = (int(coords[2]), int(coords[3]))
        cv2.rectangle(origimg, p1, p2, (0, 255, 0))
        # Keep the text anchor >= 15 px from the top/left edges so the label
        # is not clipped for boxes touching the image border.
        p3 = (max(p1[0], 15), max(p1[1], 15))
        class_index = int(label)
        if class_index == -1:
            # A -1 label is shown as background, as before.
            class_index = 0
        if 0 <= class_index < len(CLASSES):
            name = CLASSES[class_index]
        else:
            # Label outside CLASSES (e.g. a model trained on more classes).
            name = str(class_index)
        title = "%s:%.2f" % (name, score)
        print(title)
        cv2.putText(origimg, title, p3, cv2.FONT_ITALIC, 0.6, (0, 255, 0), 1)


def detect(tese_img_dir, conf_thresh=0.2):
    """Run the detector on every ``.jpg`` in a directory and save annotated copies.

    For each image, the function does the following:
    - preprocesses it and runs it through the module-level ``net``;
    - prints the forward-pass time;
    - draws every detection scoring at least ``conf_thresh``;
    - shows the result in a window and writes it to ``<dir>/gout_result/``
      under the same file name.

    Unreadable images are reported and skipped.

    Args:
        tese_img_dir: directory containing the input ``.jpg`` images. The
            misspelled name is kept so existing keyword callers keep working.
        conf_thresh: minimum detection score for a box to be drawn and labelled.
    """
    img_dir = tese_img_dir
    result_dir = os.path.join(img_dir, 'gout_result')
    if not os.path.isdir(result_dir):
        os.mkdir(result_dir)

    img_name_list = sorted(name for name in os.listdir(img_dir) if name.endswith('.jpg'))
    for cur_img in img_name_list:
        origimg = cv2.imread(os.path.join(img_dir, cur_img))
        if origimg is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print("{} could not be read, skipped".format(cur_img))
            continue

        net.blobs['data'].data[...] = _preprocess(origimg)

        time_start = time.time()
        out = net.forward()
        time_end = time.time()
        print("{} time consuming:{}".format(cur_img, time_end - time_start))

        box, conf, cls = postprocess(origimg, out)
        _draw_detections(origimg, box, conf, cls, conf_thresh)

        cv2.imshow("detect_car", origimg)
        cv2.imwrite(os.path.join(result_dir, cur_img), origimg)
        cv2.waitKey(10)


if __name__ == "__main__":
    # Script entry point: run detection over the configured test directory.
    detect(test_img_dir)

猜你喜欢

转载自：https://blog.csdn.net/touch_dream/article/details/80773001