基于GoogLeNet（Inception V1）的17类花卉分类微调（fine-tuning）训练案例

import tensorflow as tf
from tensorflow.contrib.slim import nets
slim = tf.contrib.slim
import numpy as np
/root/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
class GoogLeNet(object):
    """Fine-tune a pretrained slim Inception-v1 (GoogLeNet) on 17flowers.

    ``__init__`` builds the whole graph: the Inception-v1 backbone, a global
    average pool over its ``Mixed_5c`` feature map, dropout, and a new
    zero-initialised 17-way dense classifier trained with plain SGD.
    ``train`` restores the ImageNet checkpoint (the new dense layer is not in
    it), fine-tunes, and saves to ``model/flowerModel``; ``test`` restores that
    save and prints the accuracy on the ``tst1.txt`` split.

    Args:
        lr: learning rate for ``tf.train.GradientDescentOptimizer``.
        batch_size: number of images per batch for every input queue.
        iter_num: number of training steps performed by ``train``.
    """

    def __init__(self, lr, batch_size, iter_num):
        self.lr = lr                  # learning rate
        self.batch_size = batch_size
        self.iter_num = iter_num      # total number of training iterations

        # Start from a fresh default graph so re-running this code (e.g. a
        # notebook cell) does not complain that tensors already exist.
        tf.reset_default_graph()

        self.X = tf.placeholder(tf.float32, [None, 224, 224, 3])
        # One-hot labels; the 17flowers dataset has 17 classes.
        self.y = tf.placeholder(tf.float32, [None, 17])
        # Despite its name this is fed to tf.nn.dropout as *keep_prob*:
        # 0.5 while training, 1.0 when evaluating.
        self.dropRate = tf.placeholder(tf.float32)

        # NOTE(review): inception_v1 is built with its default is_training,
        # so batch norm behaves the same in training and evaluation and its
        # update ops are never run here -- confirm this is intended.
        with slim.arg_scope(nets.inception.inception_v1_arg_scope()):
            _, endpoints = nets.inception.inception_v1(self.X, num_classes=1001)
        # Ignore the ImageNet head and classify from the last Inception block.
        net = endpoints['Mixed_5c']
        net = tf.reduce_mean(net, [1, 2], keepdims=True, name='global_pool')
        net = tf.reshape(net, [-1, 1024])
        net = tf.nn.dropout(net, self.dropRate)
        logits = tf.layers.dense(net, 17, use_bias=True,
                                 kernel_initializer=tf.constant_initializer(0),
                                 bias_initializer=tf.constant_initializer(0))
        self.logits = logits
        self.loss = tf.losses.softmax_cross_entropy(onehot_labels=self.y, logits=logits)
        # minimize() without var_list updates every trainable variable,
        # i.e. the backbone is fine-tuned together with the new classifier.
        self.train_step = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)

        # Accuracy metric used to monitor training and to evaluate.
        self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(logits, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))

        # Used to save / restore the trained model.
        self.saver = tf.train.Saver()

        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
        self.merged_summary_op = tf.summary.merge_all()

    def read_image_label_list(self, name_list):
        """Read image paths and integer class labels from a split file.

        Each non-blank line of *name_list* holds a 1-based image index ``k``.
        The image lives at ``data/jpg/image_%04d.jpg`` and its class is
        ``(k - 1) // 80``: the dataset stores 80 consecutive images per class,
        so images 1-80 are class 0, ..., images 1281-1360 are class 16.

        Args:
            name_list: path of the split file (e.g. ``trn1.txt``).

        Returns:
            ``(img_list, label_list)``: parallel lists of path strings and ints.
        """
        img_list = []
        label_list = []

        with open(name_list) as fr:
            for line in fr:
                line = line.strip()
                if not line:          # tolerate blank / trailing lines
                    continue
                img_index = int(line)
                # Indices are 1-based; without the -1, every 80th image would
                # be pushed into the next class (and 1360 into a non-existent
                # class 17, which one_hot turns into an all-zero label).
                img_list.append('data/jpg/image_%04d.jpg' % img_index)
                label_list.append((img_index - 1) // 80)

        return img_list, label_list

    def read_file(self, name_list, augment=True):
        """Build a queue-based input pipeline of batched images and labels.

        Args:
            name_list: split file passed to ``read_image_label_list``.
            augment: apply random brightness and left/right flips. Keep True
                for training; pass False for validation/test data.

        Returns:
            ``(X, Y)`` tensors: a batch of 224x224x3 float images and one-hot
            labels of depth 17. Queue runners must be started before use.
        """
        image_list, label_list = self.read_image_label_list(name_list)
        imagepaths, labels = tf.train.slice_input_producer([image_list, label_list], shuffle=True)
        image = tf.read_file(imagepaths)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize_images(image, [224, 224])
        if augment:
            image = tf.image.random_brightness(image, 15)
            image = tf.image.random_flip_left_right(image)
        # Map [0, 255] pixel values to roughly [-1, 1].
        image = (image * 1.0 / 127.5 - 1)
        label = tf.one_hot(labels, 17)
        X, Y = tf.train.batch([image, label], batch_size=self.batch_size,
                              num_threads=2, capacity=self.batch_size * 4)
        return X, Y

    def train(self):
        """Fine-tune from the ImageNet checkpoint and save to model/flowerModel.

        Every 10 iterations, accuracy on a training batch and a (non-augmented)
        validation batch is printed and written as summaries under
        ``log/train_base`` and ``log/test_base``.
        """
        training_images, training_labels = self.read_file('trn1.txt')
        test_images, test_labels = self.read_file('val1.txt', augment=False)

        # Build the op once; creating it inside the loop would add a new op to
        # the graph on every iteration.
        local_init_op = tf.local_variables_initializer()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(local_init_op)

            # Load pretrained weights for every variable found in the
            # checkpoint; the new dense layer is missing there and keeps its
            # zero initialisation.
            variables_to_restore = slim.get_variables_to_restore()
            init_fn = slim.assign_from_checkpoint_fn(r'pre_trained/inception_v1.ckpt',
                                                     variables_to_restore,
                                                     ignore_missing_vars=True)
            init_fn(sess)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            summary_writer = tf.summary.FileWriter('log/train_base', sess.graph)
            summary_writer_test = tf.summary.FileWriter('log/test_base')

            try:
                for i in range(self.iter_num):
                    images, labels = sess.run([training_images, training_labels])
                    # Each sess.run executes every op that the fetched tensors
                    # depend on, here one forward/backward SGD step.
                    loss, _ = sess.run([self.loss, self.train_step],
                                       feed_dict={self.dropRate: 0.5,
                                                  self.X: images,
                                                  self.y: labels})

                    if i % 10 == 0:
                        # Evaluate with dropout disabled (keep_prob = 1).
                        images, labels = sess.run([training_images, training_labels])
                        train_accuracy, summary_str = sess.run(
                            [self.accuracy, self.merged_summary_op],
                            feed_dict={self.X: images, self.y: labels, self.dropRate: 1.})
                        summary_writer.add_summary(summary_str, i)

                        images, labels = sess.run([test_images, test_labels])
                        test_accuracy, summary_str = sess.run(
                            [self.accuracy, self.merged_summary_op],
                            feed_dict={self.X: images, self.y: labels, self.dropRate: 1.})
                        summary_writer_test.add_summary(summary_str, i)

                        print('iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f'
                              % (i, loss, train_accuracy, test_accuracy))

                # Saver.save needs the parent directory to exist.
                tf.gfile.MakeDirs('model')
                self.saver.save(sess, 'model/flowerModel')
            finally:
                summary_writer.close()       # close() also flushes
                summary_writer_test.close()
                coord.request_stop()
                coord.join(threads)

    def test(self):
        """Restore model/flowerModel and print accuracy on the tst1.txt split.

        The number of evaluated predictions equals the number of images listed
        in the split file instead of a hard-coded count.
        """
        test_images, test_labels = self.read_file('tst1.txt', augment=False)
        num_images = len(self.read_image_label_list('tst1.txt')[0])
        # Ceiling division: enough batches to see every image at least once.
        num_batches = (num_images + self.batch_size - 1) // self.batch_size

        with tf.Session() as sess:
            self.saver.restore(sess, 'model/flowerModel')
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                correct = []
                for _ in range(num_batches):
                    images, labels = sess.run([test_images, test_labels])
                    batch_correct = sess.run(
                        self.correct_prediction,
                        feed_dict={self.X: images, self.y: labels, self.dropRate: 1.})
                    correct.extend(batch_correct)
                # The input queue cycles, so the last batch may wrap around;
                # keep exactly num_images predictions.
                print('==' * 15)
                print('Test Accuracy: ', np.mean(np.array(correct[:num_images], dtype=np.float32)))
            finally:
                coord.request_stop()
                coord.join(threads)
# Driver: fine-tune with learning rate 0.1 and batches of 50 for 100
# iterations, then evaluate the saved checkpoint on the test split.
model = GoogLeNet(lr=0.1, batch_size=50, iter_num=100)
model.train()
model.test()
WARNING:tensorflow:From <ipython-input-2-7ce60d3cb483>:18: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
WARNING:tensorflow:Variable dense/kernel missing in checkpoint pre_trained/inception_v1.ckpt
WARNING:tensorflow:Variable dense/bias missing in checkpoint pre_trained/inception_v1.ckpt
INFO:tensorflow:Restoring parameters from pre_trained/inception_v1.ckpt
iter    0   loss    2.833214    train_accuracy  0.020000    test_accuracy   0.100000
iter    10  loss    1.716118    train_accuracy  0.580000    test_accuracy   0.760000
iter    20  loss    0.940882    train_accuracy  0.940000    test_accuracy   0.800000
iter    30  loss    0.329169    train_accuracy  0.960000    test_accuracy   0.860000
iter    40  loss    0.229579    train_accuracy  1.000000    test_accuracy   0.900000
iter    50  loss    0.096816    train_accuracy  1.000000    test_accuracy   0.940000
iter    60  loss    0.138667    train_accuracy  1.000000    test_accuracy   0.900000
iter    70  loss    0.133150    train_accuracy  1.000000    test_accuracy   0.940000
iter    80  loss    0.048020    train_accuracy  1.000000    test_accuracy   0.920000
iter    90  loss    0.057278    train_accuracy  1.000000    test_accuracy   0.880000
INFO:tensorflow:Restoring parameters from model/flowerModel
==============================
Test Accuracy:  0.94285715

猜你喜欢

转载自www.cnblogs.com/shayue/p/10390712.html