DeepLab V3+ 训练自己的数据

一、前提

官网代码:https://github.com/tensorflow/models/tree/master/research/deeplab 

1. 依赖

DeepLab依赖于以下库:

  • Numpy
  • Pillow 1.0
  • tf Slim (which is included in the "tensorflow/models/research/" checkout)
  • Jupyter notebook
  • Matplotlib
  • Tensorflow1.6及以上

2. 将库添加到PYTHONPATH

在本地运行时,tensorflow / models / research /和slim目录应该附加到PYTHONPATH。 这可以通过在 tensorflow / models / research /路径下运行以下命令来完成:

# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

注意:每次启用新终端,此命令都需要运行。 如果想避免手动运行,可以将其作为新行添加到〜/ .bashrc文件的末尾。

3. 简单测试

可以通过运行以下命令来测试是否已成功安装 Tensorflow DeepLab:

运行 model_test.py 进行快速测试:

# From tensorflow/models/research/
python deeplab/model_test.py

在PASCAL VOC 2012数据集上快速运行整个代码:

# From tensorflow/models/research/deeplab
sh local_test.sh

local_tesr.sh 脚本用于在PASCAL VOC 2012上运行本地测试。

之后在自己数据集上进行训练等操作就可以参照 local_test.sh 来编辑指令。打开脚本看一下,发现它:

(1)执行了model_test.py

(2)执行了download_and_convert_voc2012.sh

(3)从model_zoo(http://download.tensorflow.org/models)下载了模型deeplabv3_pascal_train_aug

(4)执行了train.py

(5)执行了eval.py

(6)执行了vis.py

(7)执行了export_model.py

建议仔细阅读上面提到的脚本和程序,为以后训练自己的数据提供参考。

错误1:

运行 model_test.py 进行快速测试时出错:

参考 https://github.com/tensorflow/models/issues/5523

 将 model_test.py 中140行左右的:

self.assertListEqual(scales_to_model_results.keys(),

修改为:

self.assertListEqual(list(scales_to_model_results.keys()),

错误2:

测试程序需要运行eval.py,我在这里出现了一个错误:

即:InvalidArgumentError (see above for traceback): assertion failed: [`predictions` out of bound] [Condition x < y did not hold element-wise:] [x (mean_iou/confusion_matrix/control_dependency_1:0) = ] [0 0 0...] [y (mean_iou/ToInt64_2:0) = ] [21]

参考https://github.com/tensorflow/models/issues/4203中trobr的说法:

修改 eval.py 中第145 行左右:

将:

metric_map = {}

metric_map[predictions_tag] = tf.metrics.mean_iou(

        predictions, labels, dataset.num_classes, weights=weights)

修改为:   也就是中间插入了几行

 metric_map = {}

    # insert by trobr

    indices = tf.squeeze(tf.where(tf.less_equal(

        labels, dataset.num_classes - 1)), 1)

    labels = tf.cast(tf.gather(labels, indices), tf.int32)

    predictions = tf.gather(predictions, indices)

    # end of insert

    metric_map[predictions_tag] = tf.metrics.mean_iou(

        predictions, labels, dataset.num_classes, weights=weights)

二、数据准备 

参照VOC2012的文件结构,把自己的数据和文件夹准备好。

参考download_and_convert_voc2012.sh进行数据转化。

1. label图修改(也许需要)

label图应该是没有色彩的,类别的像素标记应该是0,1,2,3......

注意:不要把 ignore_label background 混淆,ignore_label是没有做标注的,不在预测范围内的,ignore_label是不参与计算loss的。我们在mask中将 ignore_label 的灰度值标记为255,而background 标记为0

如果是voc2012这种有colormap的标签图,可以利用remove_gt_colormap.py先去掉colormap:

# from research/deeplab/datasets
python remove_gt_colormap.py \
  --original_gt_folder="/path/SegmentationClass" \
  --output_dir="/path/SegmentationClassRaw"

其中, original_gt_folder是原始标签图文件夹,output_dir是要输出的标签图文件夹的位置。

2. 数据转换为tfrecord

# from research/deeplab/datasets
python build_voc2012_data.py \
  --image_folder="/path/JPEGImages" \
  --semantic_segmentation_folder="/path/SegmentationClassRaw" \
  --list_folder="/path/ImageSets/Segmentation" \
  --image_format="jpg" \
  --output_dir="/path/tfrecord"

其中,image_folder是jpg原图文件夹,semantic_segmentation_folder是转化后label图文件夹,list_folder是train.txt、val.txt、trainval.txt所在的文件夹,output_dir是输出数据存放的文件夹。

转换后的数据保存到tfrecord(tfrecord文件夹事先建好)。

三、训练准备

1. 修改segmentation_dataset.py(注册数据集)

(1)在这段代码注册数据集,使我的数据集 voc_turbulent 拥有姓名:

_DATASETS_INFORMATION = {

    'cityscapes': _CITYSCAPES_INFORMATION,

    'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,

    'ade20k': _ADE20K_INFORMATION,

    'voc_turbulent': _VOC_TURBULENT_INFORMATION,

}

(2)参照代码中其他数据集形式,加入个人数据集描述配置:

训练、检测数据的数量修改好,类别数量也根据实际修改。

_VOC_TURBULENT_INFORMATION = DatasetDescriptor(

    splits_to_sizes={

        'train': 1413,

        'trainval': 2826,

        'val': 1413,

    },

    num_classes=21,

    ignore_label=255,

)

2. 修改train_utils.py

文件修改如下:

exclude_list = ['global_step'] 

修改为:

exclude_list = ['global_step', 'logits'] 

作用是在使用预训练权重时候,不加载该logit层。训练自己的数据集时,此处进行修改。

四、训练

模型从官网下载:https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md

python deeplab/train.py \
  --logtostderr \
  --train_split="train" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \#前面注册的数据集名称
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --train_crop_size=513 \
  --train_crop_size=513 \
  --training_number_of_steps=90000  \
  --base_learning_rate=0.0001 \
  --num_clones=3 \#3块显卡
  --train_batch_size=9 \#得是显卡数量的倍数哈
  --fine_tune_batch_norm=false \
  --initialize_last_layer=False \
  --last_layers_contain_logits_only=True \
  --tf_initial_checkpoint="/path/deeplabv3_pascal_train_aug/model.ckpt" \
  --train_logdir="/path/exp/train_on_train_set/train" \
  --dataset_dir="/path/tfrecord"

注意:

(1)学习率

 (2)batch size

(3)模型选择及参数

(4)crop size 

(5)关于initialize_last_layer和last_layers_contain_logits_only

五、验证 

python deeplab/eval.py \
  --logtostderr \
  --eval_split="val" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \
  --num_clones=3 \#3块显卡
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --eval_crop_size=513 \
  --eval_crop_size=513 \
  --checkpoint_dir="/path/exp/train_on_train_set/train" \
  --eval_logdir="/path/exp/train_on_train_set/eval" \
  --dataset_dir="/path/tfrecord" \
  --max_number_of_evaluations=1

结果不是很好: 

六、可视化

python deeplab/vis.py \
  --logtostderr \
  --vis_split="val" \
  --model_variant="xception_65" \
  --dataset="voc_turbulent" \
  --num_clones=3 \
  --atrous_rates=6 \
  --atrous_rates=12 \
  --atrous_rates=18 \
  --output_stride=16 \
  --decoder_output_stride=4 \
  --vis_crop_size=513 \
  --vis_crop_size=513 \
  --checkpoint_dir="/xxx/exp/train_on_train_set/train" \
  --vis_logdir="/xxx/exp/train_on_train_set/vis" \
  --dataset_dir="/xxx/tfrecord" \
  --max_number_of_iterations=1

七、预测单张图片 

在deeplab_demo.ipynb的基础上做些修改,为方便使用,给出网盘链接,使用时修改路径即可。

链接:https://pan.baidu.com/s/16iffY6WkOwjRezttAuulFQ 
提取码:06fy 
效果如下:

 

猜你喜欢

转载自blog.csdn.net/qq_36685744/article/details/85843257