1. 引言
在上一篇我们详述了如何将Caltech数据集转化成VOC格式的数据集,在使用之前,我们需要编写读取数据集的接口
2. 先上结果
下图是进行了50000次训练后,在测试集上进行测试的结果。我这个结果偏低,因为有person和people的区分在内部会导致一些测试出错。
下图是从网络上随意下载的一幅图像,前面的那个那么明显的人都没有检测到!!!不明白为什么,难道是穿着问题?
虽然上述两图反映出检测的准确率不是很理想,但是说明整个训练流程没有问题,我们将在后面的过程对网络的参数进行优化。
3. 训练自己的数据集
假设这里我们已经按照前一篇文章,制作了数据集。以Caltech行人数据集为例,我们需要编写数据集的读取接口,并且修改训练脚本和配置文件等。
3.1 编写数据集的读取接口
1.修改lib/datasets/pascal_voc.py文件
复制一份pascal_voc文件,重命名为caltech.py,在这份文件上进行修改。
__init__(self,image_set,year,devkit_path=None)
该函数是数据集的构造函数,主要包含了该数据集的一些属性。
修改后:
#删除掉year,因为我们的数据集也没有什么年份
def __init__(self, image_set, devkit_path=None):
    """Construct the Caltech pedestrian imdb.

    Args:
        image_set: split name, e.g. 'trainval' or 'test' (matches the
            <split>.txt file listing the image ids).
        devkit_path: root directory that contains the 'Caltech' data
            folder, laid out VOC-devkit style.
    """
    # Unlike PASCAL VOC there is no `year` argument: this dataset is
    # not versioned.
    imdb.__init__(self, image_set)
    self._image_set = image_set
    self._devkit_path = devkit_path
    self._data_path = os.path.join(self._devkit_path, 'Caltech')
    # Caltech annotations distinguish 'person' (individuals) from
    # 'people' (groups); both are kept alongside the background class.
    self._classes = ('__background__', 'person', 'people')
    self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
    self._image_ext = '.jpg'
    self._image_index = self._load_image_set_index()
    # Use ground-truth boxes as the roidb instead of selective-search
    # proposals (the RPN produces proposals during training).
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'
    # PASCAL-specific config options; use_diff=True keeps samples
    # labeled as "difficult" in the training set.
    self.config = dict(cleanup=True,
                       use_salt=True,
                       use_diff=True,
                       matlab_eval=False,
                       rpn_file=None,
                       min_size=2)
    assert os.path.exists(self._devkit_path), \
        'VOCdevkit path does not exist: {}'.format(self._devkit_path)
    assert os.path.exists(self._data_path), \
        'Caltech Data Path does not exist: {}'.format(self._data_path)
_load_pascal_annotation(self, index)根据id号,导入图片的标注文件
修改后:
def _load_pascal_annotation(self, index):
"""
根据index,读取Caltech数据集某一样图片的标注xml文件
"""
filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
#解析xml标注文件
tree = ET.parse(filename)
#读取所有的object,即person
objs = tree.findall('object')
if not self.config['use_diff']:
# Exclude the samples labeled as difficult
non_diff_objs = [
obj for obj in objs if int(obj.find('difficult').text) == 0]
# if len(non_diff_objs) != len(objs):
# print 'Removed {} difficult objects'.format(
# len(objs) - len(non_diff_objs))
objs = non_diff_objs
#行人的数量
num_objs = len(objs)
#定义该图片的标注框数组 n*4
boxes = np.zeros((num_objs, 4), dtype=np.uint16)
gt_classes = np.zeros((num_objs), dtype=np.int32)
overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
# "Seg" area for pascal is just the box area
seg_areas = np.zeros((num_objs), dtype=np.float32)
# Load object bounding boxes into a data frame.
for ix, obj in enumerate(objs):
bbox = obj.find('bndbox')
# Make pixel indexes 0-based
#这个位置删除了之前的减1
x1 = float(bbox.find('xmin').text)
y1 = float(bbox.find('ymin').text)
x2 = float(bbox.find('xmax').text)
y2 = float(bbox.find('ymax').text)
cls = self._class_to_ind[obj.find('name').text.lower().strip()]
boxes[ix, :] = [x1, y1, x2, y2]
gt_classes[ix] = cls
overlaps[ix, cls] = 1.0
seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
overlaps = scipy.sparse.csr_matrix(overlaps)
return {'boxes' : boxes,
'gt_classes': gt_classes,
'gt_overlaps' : overlaps,
'flipped' : False,
'seg_areas' : seg_areas}
最后的main函数,把路径修改一下。
if __name__ == '__main__':
    from datasets.caltech import caltech  # import the caltech imdb class
    # Smoke test: build the trainval imdb from the devkit path and drop
    # into an IPython shell to inspect the resulting roidb by hand.
    d = caltech('trainval', '/home/user2/chen_guang_hao/PeDetect/smallcorgi/Faster-RCNN_TF-master/data/VOCdevkit2007')
    res = d.roidb
    from IPython import embed; embed()
2.修改lib/datasets/factory.py文件
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Factory method for easily getting imdbs by name."""
__sets = {}
import datasets.pascal_voc
import datasets.imagenet3d
import datasets.kitti
import datasets.kitti_tracking
import numpy as np
#导入caltech文件
from datasets.caltech import caltech
def _selective_search_IJCV_top_k(split, year, top_k):
    """Return an imdb that uses the top k proposals from the selective search
    IJCV code.
    """
    dataset = datasets.pascal_voc(split, year)
    # Swap the roidb handler so proposals come from the IJCV selective
    # search output, capped at the top_k highest-ranked boxes.
    dataset.roidb_handler = dataset.selective_search_IJCV_roidb
    dataset.config['top_k'] = top_k
    return dataset
# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        # Bind split/year as lambda defaults so each factory captures its
        # own values rather than the loop's final ones.
        __sets[name] = (lambda split=split, year=year:
                        datasets.pascal_voc(split, year))

"""
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
for top_k in np.arange(1000, 11000, 1000):
    for year in ['2007', '2012']:
        for split in ['train', 'val', 'trainval', 'test']:
            name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)
            __sets[name] = (lambda split=split, year=year, top_k=top_k:
                            _selective_search_IJCV_top_k(split, year, top_k))
"""

# Set up voc_<year>_<split> using selective search "fast" mode
# NOTE(review): re-registers voc_2007_* with factories identical to the
# first loop above -- redundant but harmless.
for year in ['2007']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        print name
        __sets[name] = (lambda split=split, year=year:
                        datasets.pascal_voc(split, year))

# KITTI dataset
for split in ['train', 'val', 'trainval', 'test']:
    name = 'kitti_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
                    datasets.kitti(split))

# Set up coco_2014_<split>
# NOTE(review): `coco` is never imported in this file; invoking any coco_*
# factory will raise NameError -- confirm the missing import upstream.
for year in ['2014']:
    for split in ['train', 'val', 'minival', 'valminusminival']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up coco_2015_<split>
for year in ['2015']:
    for split in ['test', 'test-dev']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# NTHU dataset
# NOTE(review): datasets.nthu is not explicitly imported here -- presumably
# exposed by the datasets package; verify before relying on nthu_* imdbs.
for split in ['71', '370']:
    name = 'nthu_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
                    datasets.nthu(split))

# Register our Caltech pedestrian dataset splits; the lambdas bind the
# split name and devkit path as defaults so each factory is self-contained.
devkit = '/home/user2/chen_guang_hao/PeDetect/smallcorgi/Faster-RCNN_TF-master/data/VOCdevkit2007'
for split in ['trainval', 'test']:
    name = 'caltech_{}'.format(split)
    __sets[name] = (lambda imageset=split,
                    devkit=devkit: caltech(imageset, devkit))
def get_imdb(name):
    """Get an imdb (image database) by name.

    Looks *name* up in the __sets registry and invokes the stored
    zero-argument factory.

    Raises:
        KeyError: if *name* was never registered.
    """
    # `in` is the idiomatic membership test; dict.has_key() is Python-2-only
    # and was removed in Python 3.
    if name not in __sets:
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
def list_imdbs():
    """List all registered imdbs."""
    # Returns the registered names (a list on Python 2, which this file
    # targets given its `print` statements).
    return __sets.keys()
3.修改lib/network/VGGnet_train.py和VGGnet_test.py文件
修改n_classes的值为3
3.2修改配置文件lib/fast_rcnn/config.py
3.3修改experiments/faster_rcnn_end2end.sh
#!/bin/bash
# Usage:
# ./experiments/scripts/faster_rcnn_end2end.sh GPU NET DATASET [options args to {train,test}_net.py]
# DATASET is either pascal_voc or coco.
#
# Example:
# ./experiments/scripts/faster_rcnn_end2end.sh 0 VGG_CNN_M_1024 pascal_voc \
#   --set EXP_DIR foobar RNG_SEED 42 TRAIN.SCALES "[400, 500, 600, 700]"
set -x
set -e
export PYTHONUNBUFFERED="True"
# Positional arguments: device type, device id, network name, dataset name.
DEV=$1
DEV_ID=$2
NET=$3
DATASET=$4
array=( $@ )
len=${#array[@]}
# Everything after the 4th argument is forwarded to train/test_net.py.
EXTRA_ARGS=${array[@]:4:$len}
EXTRA_ARGS_SLUG=${EXTRA_ARGS// /_}
case $DATASET in
  pascal_voc)
    TRAIN_IMDB="voc_2007_trainval"
    TEST_IMDB="voc_2007_test"
    PT_DIR="pascal_voc"
    ITERS=70000
    ;;
  # Added: the Caltech pedestrian imdbs registered in lib/datasets/factory.py.
  caltech)
    TRAIN_IMDB='caltech_trainval'
    TEST_IMDB="caltech_test"
    PT_DIR="caltech"
    ITERS=500000
    ;;
  coco)
    # This is a very long and slow training schedule
    # You can probably use fewer iterations and reduce the
    # time to the LR drop (set in the solver to 350,000 iterations).
    TRAIN_IMDB="coco_2014_train"
    TEST_IMDB="coco_2014_minival"
    PT_DIR="coco"
    ITERS=490000
    ;;
  *)
    echo "No dataset given"
    exit
    ;;
esac
# Tee all stdout/stderr into a timestamped log file.
LOG="experiments/logs/faster_rcnn_end2end_${NET}_${EXTRA_ARGS_SLUG}.txt.`date +'%Y-%m-%d_%H-%M-%S'`"
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"
time python ./tools/train_net.py --device ${DEV} --device_id ${DEV_ID} \
  --weights data/pretrain_model/VGG_imagenet.npy \
  --imdb ${TRAIN_IMDB} \
  --iters ${ITERS} \
  --cfg experiments/cfgs/faster_rcnn_end2end.yml \
  --network VGGnet_train \
  ${EXTRA_ARGS}
set +x
# Recover the path of the final snapshot from the training log.
NET_FINAL=`grep -B 1 "done solving" ${LOG} | grep "Wrote snapshot" | awk '{print $4}'`
set -x
time python ./tools/test_net.py --device ${DEV} --device_id ${DEV_ID} \
  --weights ${NET_FINAL} \
  --imdb ${TEST_IMDB} \
  --cfg experiments/cfgs/faster_rcnn_end2end.yml \
  --network VGGnet_test \
  ${EXTRA_ARGS}