File overview
synset.txt: the list of class labels
vgg16-20160129.tfmodel: the pre-trained VGG16 network structure and node parameters
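The contents of synset.txt are not shown in this post. The sketch below assumes the common ImageNet layout: one class per line, in the same order as the model's output, optionally prefixed by a WordNet ID such as `n01440764`. Under that assumption, the labels could be read like this (`parse_labels` is a hypothetical helper name, not part of the original code):

```python
import re

# Assumed line format: "n01440764 tench, Tinca tinca" (the WordNet ID prefix is optional)
WNID_PREFIX = re.compile(r"^n\d{8}\s+")


def parse_labels(lines):
    """Return one human-readable label per non-empty line, in file order."""
    labels = []
    for line in lines:
        line = line.strip()
        if line:  # ignore blank lines, e.g. a trailing newline at end of file
            labels.append(WNID_PREFIX.sub("", line))
    return labels


# Usage with the real file:
# with open("model/synset.txt", encoding="utf-8") as f:
#     labels = parse_labels(f)
```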
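Define the input placeholder
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.GraphDef, tf.Session)

# Batch of 224x224 RGB images; None allows any batch size
images = tf.placeholder("float", [None, 224, 224, 3])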
Load the model
with open("model/vgg16-20160129.tfmodel", mode='rb') as f:
    fileContent = f.read()
Create the graph and import the pre-trained model
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
# Map the model's "images" input onto our placeholder
tf.import_graph_def(graph_def, input_map={"images": images})
graph = tf.get_default_graph()
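`tf.import_graph_def` prefixes every imported node with its default name `import/`, which is why the output tensor is later fetched as `import/prob:0`. A minimal self-contained sketch of the same load-and-remap mechanism, assuming TensorFlow 2.x with the `tf.compat.v1` API and using a toy graph (an `images` input followed by a node named `prob`) in place of the real VGG16 file; `toy_graph_bytes` and `predict` are illustrative names only:

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def toy_graph_bytes():
    """Serialize a toy graph 'images' -> softmax -> 'prob' (stand-in for the .tfmodel bytes)."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, [None, 3], name="images")
        tf.identity(tf.nn.softmax(x), name="prob")
    return g.as_graph_def().SerializeToString()


def predict(file_content, batch):
    """Parse GraphDef bytes, remap 'images' to a fresh placeholder, and evaluate 'prob'."""
    graph_def = tf1.GraphDef()
    graph_def.ParseFromString(file_content)
    g = tf.Graph()
    with g.as_default():
        images = tf1.placeholder(tf.float32, [None, 3])
        # input_map makes the imported graph read from our placeholder instead
        tf1.import_graph_def(graph_def, input_map={"images": images})
        # imported nodes carry the default "import/" prefix
        prob = g.get_tensor_by_name("import/prob:0")
        with tf1.Session(graph=g) as sess:
            return sess.run(prob, feed_dict={images: batch})


probs = predict(toy_graph_bytes(),
                np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], dtype=np.float32))
print(probs)
```

Building the toy graph in memory keeps the example runnable without downloading the model; with the real file, `file_content` would be the bytes read from `vgg16-20160129.tfmodel`.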
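Build the feed_dict
# load_image is not shown in this excerpt (see the complete code linked below);
# it is assumed to return one 224x224x3 RGB array per image
imgs = []
for i in ['cat.jpg', 'airplane.jpg', 'zebra.jpg', 'pig.jpg']:
    img = load_image('model/pic/' + i)
    plt.imshow(img)
    plt.show()
    imgs.append(img)
img_num = len(imgs)
batch = np.array(imgs).reshape((img_num, 224, 224, 3))
assert batch.shape == (img_num, 224, 224, 3)
feed_dict = {images: batch}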
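Because the placeholder expects 224×224×3 inputs, `load_image` presumably center-crops and resizes each photo. The numpy sketch below is a hypothetical stand-in (`preprocess` is an illustrative name): it assumes RGB input scaled to [0, 1] and uses nearest-neighbour resizing for simplicity, which may differ from the interpolation the original helper uses. It builds the batch with `np.stack`, which raises an error if any image has the wrong shape rather than silently reshaping.

```python
import numpy as np

SIZE = 224  # input height/width expected by the placeholder


def preprocess(img, size=SIZE):
    """Center-crop an HxWx3 uint8 RGB array to a square, resize (nearest neighbour), scale to [0, 1]."""
    h, w = img.shape[:2]
    short = min(h, w)
    top, left = (h - short) // 2, (w - short) // 2
    crop = img[top:top + short, left:left + short]
    # map each output row/column to a source row/column inside the square crop
    idx = np.arange(size) * short // size
    resized = crop[idx][:, idx]
    return resized.astype(np.float32) / 255.0


# two random "photos" of different sizes stand in for cat.jpg, zebra.jpg, ...
rng = np.random.default_rng(0)
photos = [rng.integers(0, 256, (300, 400, 3), dtype=np.uint8),
          rng.integers(0, 256, (500, 250, 3), dtype=np.uint8)]
batch = np.stack([preprocess(p) for p in photos])
print(batch.shape)  # (2, 224, 224, 3)
```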
Run the prediction
prob_tensor = graph.get_tensor_by_name("import/prob:0")
with tf.Session() as sess:
    prob = sess.run(prob_tensor, feed_dict=feed_dict)
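`prob` holds one row of class probabilities per input image. As a sketch of turning it into readable predictions, assuming each row is ordered to match the lines of synset.txt, the top-k classes can be picked with `np.argsort` (`top_k` is a hypothetical helper; the toy labels and probabilities are made up for illustration):

```python
import numpy as np


def top_k(prob, labels, k=5):
    """For each row of prob, return the k (label, probability) pairs with the highest probability."""
    results = []
    for row in prob:
        best = np.argsort(row)[::-1][:k]  # class indices by descending probability
        results.append([(labels[i], float(row[i])) for i in best])
    return results


toy_labels = ["tabby cat", "airliner", "zebra", "hog"]
toy_prob = np.array([[0.70, 0.10, 0.05, 0.15],
                     [0.02, 0.90, 0.03, 0.05]])
for preds in top_k(toy_prob, toy_labels, k=2):
    print(preds)
```

With the real model, `labels` would come from synset.txt and `prob` from the `sess.run` call above.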
The prediction results are shown below.
The complete code is on my GitHub: https://github.com/mjDelta/tensorflow-examples/blob/master/load_vgg16.py
Pre-trained model on Baidu Netdisk: https://pan.baidu.com/s/1mhEzH4s password: u7ap