记录一次使用inception-v3迁移学习训练自己的图片分类器

第一步生成自己的模型

首先要将收集的图片按标签分类放到不同的文件夹下如图:

其中有五中花的分类:daisy,dandelion,roses,sunflowers,tulips这五种花的分类每个分类下有500张该类图片

# github上下载TensorFlow 找到tensorflow\examples\image_retraining\retrain.py 可能目前转移到了hub中而且这个需要翻墙才能用
# 编写dos命令执行脚本:参数image_dir指定
# python C:\Users\admin\PycharmProjects\TensorFlowTestNew\hub-master\hub-master\examples\image_retraining\retrain.py --image_dir flower_photos
# 由于TensorFlow上最新的retrain.py要去寻找一个默认的在线的模型(tfhub_module的默认值),但是那个模型国内没有翻墙vpn无法访问到,
# 所以在网上找了个其他版本的retrain.py(https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/scripts/retrain.py)能适用于TensorFlow_1.13.1,
# 并将里面的报错(报错原因是因为新旧版本方法位置变化了,当前使用的TensorFlow是1.13.1)予以修改,也可能不报错,之后得到了一个我自己的myretrain.py执行下边的脚本即可
# python C:\Users\admin\PycharmProjects\TensorFlowTestNew\TensorFlow\inception利用\myretrain.py --image_dir flower_photos --how_many_training_steps 200 --model_dir C:\Users\admin\PycharmProjects\TensorFlowTestNew\TensorFlow\inception_pretrain --output_graph output_graph.pd --output_labels output_labels.txt
# --image_dir flower_photos 要分类的图片地址
# --how_many_training_steps 200 训练周期是200,默认好像是4000
# --model_dir C:\Users\admin\PycharmProjects\TensorFlowTestNew\TensorFlow\inception_pretrain 我们要使用的模型的所在位置那个tgz包的位置,不指定回去自己下载一个
# --output_graph output_graph.pd 训练好的我自己的分类模型
# --output_labels output_labels.txt  输出一个标签位置
# 还有很多参数都有默认设置,到多数的输出文件位置都是默认在所在盘符的\tmp下,可以去代码中 argparse.ArgumentParser 下查看

测试使用自己的新模型

# coding: UTF-8
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt

# 创建一个图来存放google调整好的模型 inception_pretrain\classify_image_graph_def.pb
# 结果数组与C:\Users\admin\PycharmProjects\TensorFlowTestNew\TensorFlow\inception利用\output_labels.txt文件中的顺序要一致
res = ['daisy','dandelion']
with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')#获取新模型最后的输出节点叫做final_result,可以从tensorboard中的graph中看到,其中名字后面的’:’之后接数字为EndPoints索引值(An operation allocates memory for its outputs, which are available on endpoints :0, :1, etc, and you can think of each of these endpoints as a Tensor.),通常情况下为0,因为大部分operation都只有一个输出。
    # 遍历目录
    for root, dirs, files in os.walk('testImage/'):#预测图片的位置
        for file in files:
            image_data = tf.gfile.FastGFile(os.path.join(root, file), 'rb').read()#Returns the contents of a file as a string.
            predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})#tensorboard中的graph中可以看到DecodeJpeg/contents是模型的输入变量名字
            predictions = np.squeeze(predictions)

            image_path = os.path.join(root, file)
            print(image_path)
            #展示图片
            # img = plt.imread(image_path)#只能读png图,所以不能显示其他图片,训练非png图时把这段注释掉,他只是一个显示作用
            # plt.imshow(img)
            # plt.axis('off')
            # plt.show()

            top_k = predictions.argsort()[-2:][::-1]#概率最高的后2个,然后在倒排一下
            for node_id in top_k:
                score = predictions[node_id]
                print('%s (score=%.5f)' % (res[node_id], score))
            print()

结果:

均正确

猜你喜欢

转载自blog.csdn.net/qq_16320025/article/details/89154488