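Preface
For my graduation project I plan to work on a deep learning topic applied to automation. Over the winter break I read a few books and, building on what I already knew, gained some experience training multilayer perceptrons and shallow convolutional neural networks on the MNIST dataset. Reinforcement learning, generative adversarial networks and deep residual networks, however, are still only names to me, let alone something I could discuss, so I decided to learn by studying projects on GitHub. I had read a little about semantic segmentation and learned that Google's DeepLabv3+ framework is highly regarded, so I took it as my starting point.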
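Environment
Searching GitHub for DeepLabv3+, the first result is tensorflow-deeplab-v3-plus. Download the zip archive or clone it with git: git clone https://github.com/rishizek/tensorflow-deeplab-v3-plus.git
I am using Windows 10, with cmder in place of cmd as the command line (PowerShell also works). The GPU environment is CUDA 10.0 + cuDNN + tensorflow-nightly-gpu; the setup is documented in another post on my blog.
I never managed to get a working proxy on Linux, and without one downloading the files is a headache, so I had to do everything on Windows.
Project setup:
First I collected the parameter configuration into a YAML file: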
tf_record:
  data_dir: ./dataset/VOCdevkit/VOC2012
  output_path: ./dataset
  train_data_list: ./dataset/train.txt
  valid_data_list: ./dataset/val.txt
  image_data_dir: JPEGImages
  label_data_dir: SegmentationClassAug
evaluate:
  image_data_dir: dataset/VOCdevkit/VOC2012/JPEGImages
  label_data_dir: dataset/VOCdevkit/VOC2012/SegmentationClassAug
  evaluation_data_list: ./dataset/val.txt
  model_dir: ./model
  base_architecture: resnet_v2_101
  output_stride: 16
export_inference_graph:
  model_dir: ./model
  export_dir: dataset/export_output
  base_architecture: resnet_v2_101
  output_stride: 16
inference:
  data_dir: dataset/VOCdevkit/VOC2012/JPEGImages
  output_dir: ./dataset/inference_output
  infer_data_list: ./dataset/sample_images_list.txt
  model_dir: ./model
  base_architecture: resnet_v2_101
  output_stride: 16
  debug: store_true
train:
  model_dir: ./model
  train_epochs: 26
  epochs_per_eval: 1
  tensorboard_images_max_outputs: 6
  batch_size: 10
  learning_rate_policy: poly
  max_iter: 30000
  data_dir: ./dataset/
  base_architecture: resnet_v2_101
  pre_trained_model: ./ini_checkpoints/resnet_v2_101/resnet_v2_101.ckpt
  output_stride: 16
  initial_learning_rate: 7e-3
  end_learning_rate: 1e-6
  initial_global_step: 0
  weight_decay: 2e-4
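As a rough illustration of how one group of these parameters maps onto a command line, here is a minimal sketch. The helper name build_command is hypothetical and not part of the repository, and treating the value store_true as a bare switch is my own convention for this listing, not something the project defines.

```python
def build_command(script, params):
    """Turn a dict of parameters into an argv list for `python <script>`."""
    argv = ["python", script]
    for name, value in params.items():
        if value == "store_true":
            # Assumption: "store_true" marks a boolean switch, so only the
            # bare flag is passed, without a value.
            argv.append(f"--{name}")
        else:
            argv.extend([f"--{name}", str(value)])
    return argv


# A few of the training parameters from the listing above.
train_params = {
    "model_dir": "./model",
    "batch_size": 10,
    "base_architecture": "resnet_v2_101",
}

print(" ".join(build_command("train.py", train_params)))
```

The resulting list can be passed to subprocess.run directly, which avoids quoting problems with Windows paths.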
Steps
- Download and extract PASCAL VOC training/validation data (2GB tar file), specifying the location with the --data_dir.
- Download and extract augmented segmentation data (Thanks to DrSleep), specifying the location with --data_dir and --label_data_dir (namely, label_data_dir).
- For inference the trained model with 77.31% mIoU on the Pascal VOC 2012 validation dataset is available here. Download and extract to --model_dir.
- For training, you need to download and extract pre-trained Resnet v2 101 model from slim specifying the location with --pre_trained_model.
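In other words:
- Download the PASCAL VOC training/validation data and extract it to the path given by data_dir. (Extracting it straight into the dataset folder is enough.)
- Download the augmented segmentation data and extract it to the label_data_dir path under the data_dir folder. (Extracting it into \dataset\VOCdevkit\VOC2012 is enough.)
- Download the trained model and extract it to the model_dir path. (Extracting it into ./model is enough.)
- Download the pre-trained Resnet v2 101 model and extract it to the pre_trained_model path. (First create the ./ini_checkpoints/resnet_v2_101/ directory to match the path parameter, then extract.)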
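Before running anything, it can help to confirm that the extracted files ended up where the configuration expects them. The following is a minimal sketch, assuming the script is run from the project root; the helper name missing_paths is hypothetical, and the paths are the ones from the configuration listing.

```python
from pathlib import Path

# Paths taken from the configuration, relative to the project root.
EXPECTED = [
    "dataset/VOCdevkit/VOC2012/JPEGImages",
    "dataset/VOCdevkit/VOC2012/SegmentationClassAug",
    "model",
    "ini_checkpoints/resnet_v2_101/resnet_v2_101.ckpt",
]


def missing_paths(root, expected=EXPECTED):
    """Return the expected paths that do not exist under root."""
    root = Path(root)
    return [p for p in expected if not (root / p).exists()]


if __name__ == "__main__":
    for path in missing_paths("."):
        print("missing:", path)
```

If nothing is printed, all four locations exist; the check says nothing about whether the archives were extracted completely.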
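Data conversion
Run python create_pascal_tf_record.py --data_dir DATA_DIR --image_data_dir IMAGE_DATA_DIR --label_data_dir LABEL_DATA_DIR. Here I simply rely on the defaults, so python create_pascal_tf_record.py is enough. From experience I expected an = between each flag and its value and worried that a plain space might cause an error; argparse, which the store_true action in the configuration suggests these scripts use, accepts both --flag value and --flag=value.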
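Whether a space or an = separates a flag from its value is easy to check in isolation. This is a small sketch, assuming the script parses its arguments with Python's argparse; the parser below is a stand-in rather than the project's actual parser, and D:/data/VOC2012 is a made-up path.

```python
import argparse

# Stand-in parser with a single option; the default mirrors the configuration.
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", default="./dataset/VOCdevkit/VOC2012")

spaced = parser.parse_args(["--data_dir", "D:/data/VOC2012"])  # --flag value
equals = parser.parse_args(["--data_dir=D:/data/VOC2012"])     # --flag=value
default = parser.parse_args([])                                # flag omitted

# Both spellings yield the same value; omitting the flag yields the default.
print(spaced.data_dir == equals.data_dir)
print(default.data_dir)
```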
Training the model
Run python train.py --model_dir MODEL_DIR --pre_trained_model PRE_TRAINED_MODEL. If the data is laid out as configured above, python train.py with the default parameters is enough.
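The training parameters include learning_rate_policy: poly together with initial_learning_rate, end_learning_rate and max_iter. A common form of polynomial decay can be sketched as follows. The exponent power=0.9 is an assumption on my part (it does not appear in the configuration listing), and the function is an illustration of the schedule, not the project's own code.

```python
def poly_learning_rate(step, initial_lr=7e-3, end_lr=1e-6,
                       max_iter=30000, power=0.9):
    """Polynomial decay from initial_lr to end_lr over max_iter steps.

    power=0.9 is an assumed value; the other defaults come from the
    training parameters in the configuration.
    """
    step = min(step, max_iter)  # hold the rate at end_lr after max_iter
    return (initial_lr - end_lr) * (1 - step / max_iter) ** power + end_lr


for step in (0, 15000, 30000):
    print(step, poly_learning_rate(step))
```

The rate starts at the initial value, falls monotonically, and reaches end_learning_rate at max_iter.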
Using the model
Run python inference.py --data_dir DATA_DIR --infer_data_list INFER_DATA_LIST --model_dir MODEL_DIR to predict with the model, or rely on the default parameters.