第五期 基于 Inception-V3 重新训练网络 《显卡就是开发板》

版权声明:本文为aggresss原创文章,未经博主允许不得转载。 作者:[email protected] https://blog.csdn.net/aggresss/article/details/78541465

  Tensorflow官方有很多值得实践一下的例子,这一期我们使用Tensorflow官方提供的一个重新训练网络的例子进行演示,下面时这个例子的链接:
  https://www.tensorflow.org/versions/master/tutorials/image_retraining
  https://codelabs.developers.google.com/codelabs/tensorflow-for-poets

实验很简单,一共四个步骤,通过修改 inception_v3 模型的最后两层来使网络适应新的数据集和分类。
Step 1: 下载tensorflow源文件

cd ~
git clone https://github.com/tensorflow/tensorflow
cd tensorflow

Step2: 下载数据集

cd ~
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz

Step3: 重新训练网络

cd ~/tensorflow
python tensorflow/examples/image_retraining/retrain.py --image_dir ~/flower_photos

Step4: 验证训练后的网络

cd ~/tensorflow
python tensorflow/examples/label_image/label_image.py \
--graph=/tmp/output_graph.pb \
--labels=/tmp/output_labels.txt \
--input_layer=Mul \
--output_layer=final_result \
--image=$HOME/flower_photos/daisy/21652746_cc379e0eea_m.jpg

可以得到类似下面的输出:

daisy 0.997363
sunflowers 0.0018164
dandelion 0.00066717
tulips 0.00013287
roses 2.09211e-05

  具体的实验细节可以参考通过查看源码了解,主要就是 retrain.py 和 label_images.py,在执行retain.py的时候,自动下载了inception_v3模型基于ImageNet数据集训练好的文件,默认情况下在/tmp/imagenet/ 目录下,需要注意的是/tmp 文件夹是临时文件夹,默认情况下Ubuntu会在每次开机时进行清理,所以如果需要请及时备份里面的文件,或者给retrain.py 传入 –model-dir 参数修改目录。nception-2015-12-05.tgz 就是已下载的模型文件,解压后的 classify_image_graph_def.pb 就是已经固化的可以直接使用的文件。
tensorflow的训练结果持久化有两种方式:
  1.如果网络已经训练完毕需要固化时使用 tf.gfile.GFile 将所有数据保存为 Protocol Buffer 格式的 *.pb 文件;
  2. 当网络训练未完成,但需要暂时保存时使用 tf.train.Saver 保存为 checkpoint 格式的文件。
/tmp/imagenet/classify_image_graph_def.pb 就是已经freezen的网络的二进制体现,当然tensorflow 也为它提供了一种可视化的查看方式– tensorboard ,通过以下命令可以查看classify_image_graph_def.pb 文件中的 graph 结构

cd ~/tensorflow
python tensorflow/python/tools/import_pb_to_tensorboard.py --model_dir=/tmp/imagenet/classify_image_graph_def.pb --log_dir=/tmp/classify_image_graph
tensorboard --logdir=/tmp/classify_image_graph

然后打开浏览器 http://localhost:6006 ,在 graph 中就可以查看这个网络的结构。

然后,我们再来看一下文章开始时重新训练的网络的结构

tensorboard --port 6007 --logdir /tmp/retrain_logs

注意为了避免和刚才的tensorboard端口重复,增加了 –port 6007 参数,打开浏览器 http://localhost:6007 可以和刚才没有修改的网络做一下对比,查看做了哪些修改。

这里写图片描述

备注: retrain.py 的默认参数列表

--image_dir=''
--output_graph'='/tmp/output_graph.pb'
--intermediate_output_graphs_dir='/tmp/intermediate_graph/'
--intermediate_store_frequency=0
--output_labels='/tmp/output_labels.txt'
--summaries_dir='/tmp/retrain_logs'
--how_many_training_steps=4000
--learning_rate=0.01
--testing_percentage=10
--validation_percentage=10
--eval_step_interval=10 
--train_batch_size=100
--test_batch_size=1
--validation_batch_size=100
--print_misclassified_test_images=False
--model_dir'='/tmp/imagenet'
--bottleneck_dir='/tmp/bottleneck'
--final_tensor_name='final_result'
--flip_left_right=False
--random_crop=0
--random_scale=0
--random_brightness=0
--architecture'='inception_v3'

猜你喜欢

转载自blog.csdn.net/aggresss/article/details/78541465