Caltech Dataset Reading Interface

1. Introduction

In the previous post we described in detail how to convert the Caltech dataset into a VOC-format dataset. Before we can use it, we need to write an interface for reading the dataset.

2. Results First

The figure below shows the results of evaluating on the test set after 50,000 training iterations. My numbers are on the low side, because keeping person and people as separate classes internally causes some evaluation errors.
[Figure: detection results on the Caltech test set]
The next figure is an image downloaded at random from the web. The very obvious person in the foreground was not detected at all! I do not understand why; could it be a clothing issue?
[Figure: detection result on an image from the web]
Although the two figures above show that the detection accuracy is not yet ideal, they confirm that the whole training pipeline works. We will tune the network parameters in a later post.

3. Training on Your Own Dataset

Assume the dataset has already been prepared as described in the previous post. Taking the Caltech pedestrian dataset as an example, we need to write a reading interface for the dataset and modify the training script and configuration files.
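
Throughout this section I assume the VOC-style layout produced in the previous post (the root path below is from my machine; adjust it to yours):

data/VOCdevkit2007/Caltech/
    Annotations/        one XML annotation file per image
    ImageSets/Main/     trainval.txt / test.txt, one image index per line
    JPEGImages/         the converted .jpg frames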

3.1 Writing the Dataset Reading Interface

1. Modify the lib/datasets/pascal_voc.py file
Make a copy of pascal_voc.py, rename it caltech.py, and do all the edits in that copy. Remember to rename the class itself from pascal_voc to caltech as well, since the main block below and factory.py refer to it by that name.
__init__(self, image_set, year, devkit_path=None)
This is the dataset's constructor; it mainly sets up the dataset's attributes.
After modification:

# Drop the year argument; our dataset has no notion of a year
def __init__(self, image_set, devkit_path=None):
        imdb.__init__(self, image_set)
        # image_set refers to the split file, e.g. trainval.txt or test.txt
        self._image_set = image_set
        # The root path of the dataset devkit; in our project it is /home/user2/chen_guang_hao/PeDetect/smallcorgi/Faster-RCNN_TF-master/data/VOCdevkit2007, under which the Caltech dataset is placed
        self._devkit_path = devkit_path
        # See the comment above
        self._data_path = os.path.join(self._devkit_path, 'Caltech')
        # Note: this is the place where the class list must be changed
        # The Caltech annotations distinguish several kinds of people (person, people, etc.), so I kept the two classes person and people; normally it would just be __background__ and person
        self._classes = ('__background__', 'person', 'people')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        #self._roidb_handler = self.selective_search_roidb
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       # 'use_diff' is set to True here (I no longer remember exactly why)
                       'use_diff'    : True,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Caltech Data Path does not exist: {}'.format(self._data_path)

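The constructor above calls self._load_image_set_index(), and training later resolves image files through image_path_from_index(); both are carried over unchanged from the pascal_voc.py copy and only depend on self._data_path. For reference, this is roughly what they look like upstream, which is why the Caltech directory needs ImageSets/Main/<image_set>.txt and a JPEGImages/ folder:

def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file,
        e.g. Caltech/ImageSets/Main/trainval.txt
        """
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
                'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index
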
_load_pascal_annotation(self, index) loads an image's annotation file given its index.
After modification:

def _load_pascal_annotation(self, index):
        """
        Load the XML annotation file for one image of the Caltech dataset, given its index.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')

        # Parse the XML annotation file
        tree = ET.parse(filename)
        # Collect all annotated objects, i.e. the persons
        objs = tree.findall('object')

        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find('difficult').text) == 0]
            # if len(non_diff_objs) != len(objs):
            #     print 'Removed {} difficult objects'.format(
            #         len(objs) - len(non_diff_objs))
            objs = non_diff_objs

        # Number of pedestrians in this image
        num_objs = len(objs)
        # Bounding-box array for this image, shape (num_objs, 4)
        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Note: the '- 1' that pascal_voc.py uses to make pixel
            # indexes 0-based has been removed here
            x1 = float(bbox.find('xmin').text) 
            y1 = float(bbox.find('ymin').text) 
            x2 = float(bbox.find('xmax').text) 
            y2 = float(bbox.find('ymax').text) 
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes' : boxes,
                'gt_classes': gt_classes,
                'gt_overlaps' : overlaps,
                'flipped' : False,
                'seg_areas' : seg_areas}

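To see exactly what this parser consumes, here is a minimal, self-contained check; the toy XML below is just a sketch that mirrors the VOC-style tags the converter from the previous post writes:

import xml.etree.ElementTree as ET

xml_str = """
<annotation>
  <object>
    <name>person</name>
    <difficult>0</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>50</xmax><ymax>120</ymax></bndbox>
  </object>
</annotation>"""

root = ET.fromstring(xml_str)
for obj in root.findall('object'):
    bbox = obj.find('bndbox')
    box = [float(bbox.find(k).text) for k in ('xmin', 'ymin', 'xmax', 'ymax')]
    print obj.find('name').text, box   # prints: person [10.0, 20.0, 50.0, 120.0]
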
Finally, update the path in the main block. Each entry of d.roidb is exactly the dictionary returned by _load_pascal_annotation, so running this file directly (it drops into an IPython shell) is a convenient way to inspect the parsed boxes and classes before training.

if __name__ == '__main__':
    from datasets.caltech import caltech  # import the caltech module
    d = caltech('trainval', '/home/user2/chen_guang_hao/PeDetect/smallcorgi/Faster-RCNN_TF-master/data/VOCdevkit2007')
    res = d.roidb
    from IPython import embed; embed()

2. Modify the lib/datasets/factory.py file

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Factory method for easily getting imdbs by name."""

__sets = {}

import datasets.pascal_voc
import datasets.imagenet3d
import datasets.kitti
import datasets.kitti_tracking
import numpy as np
# Import the caltech module
from datasets.caltech import caltech

def _selective_search_IJCV_top_k(split, year, top_k):
    """Return an imdb that uses the top k proposals from the selective search
    IJCV code.
    """
    imdb = datasets.pascal_voc(split, year)
    imdb.roidb_handler = imdb.selective_search_IJCV_roidb
    imdb.config['top_k'] = top_k
    return imdb

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year))
"""
# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode
# but only returning the first k boxes
for top_k in np.arange(1000, 11000, 1000):
    for year in ['2007', '2012']:
        for split in ['train', 'val', 'trainval', 'test']:
            name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)
            __sets[name] = (lambda split=split, year=year, top_k=top_k:
                    _selective_search_IJCV_top_k(split, year, top_k))
"""

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        print name
        __sets[name] = (lambda split=split, year=year:
                datasets.pascal_voc(split, year))

# KITTI dataset
for split in ['train', 'val', 'trainval', 'test']:
    name = 'kitti_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
            datasets.kitti(split))

# Set up coco_2014_<split>
for year in ['2014']:
    for split in ['train', 'val', 'minival', 'valminusminival']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# Set up coco_2015_<split>
for year in ['2015']:
    for split in ['test', 'test-dev']:
        name = 'coco_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: coco(split, year))

# NTHU dataset
for split in ['71', '370']:
    name = 'nthu_{}'.format(split)
    print name
    __sets[name] = (lambda split=split:
            datasets.nthu(split))

# Register our Caltech dataset
devkit = '/home/user2/chen_guang_hao/PeDetect/smallcorgi/Faster-RCNN_TF-master/data/VOCdevkit2007'
for split in ['trainval', 'test']:
    name = 'caltech_{}'.format(split)
    __sets[name] = (lambda imageset=split, devkit=devkit:
            caltech(imageset, devkit))

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()

def list_imdbs():
    """List all registered imdbs."""
    return __sets.keys()

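A quick way to confirm the registration worked (a sketch; run it from a context where lib/ is on the PYTHONPATH, as tools/train_net.py does, and on the machine where the dataset path exists):

from datasets.factory import get_imdb, list_imdbs

print list_imdbs()        # should now include 'caltech_trainval' and 'caltech_test'
imdb = get_imdb('caltech_trainval')
print imdb.num_classes    # 3
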
3. Modify the lib/networks/VGGnet_train.py and VGGnet_test.py files
Change the value of n_classes to 3 (__background__, person, people).
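
In the upstream VGGnet_train.py and VGGnet_test.py the class count is a module-level constant sized for the 21 VOC classes, so the edit is a single line (a sketch, assuming the upstream layout):

n_classes = 3   # __background__ + person + people (was 21 for PASCAL VOC)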

3.2 Modify the configuration file lib/fast_rcnn/config.py

3.3 Modify experiments/scripts/faster_rcnn_end2end.sh

#!/bin/bash
# Usage:
# ./experiments/scripts/faster_rcnn_end2end.sh DEV DEV_ID NET DATASET [optional args to {train,test}_net.py]
# DATASET is pascal_voc, coco, or (with the case added below) caltech.
#
# Example:
# ./experiments/scripts/faster_rcnn_end2end.sh gpu 0 VGG_CNN_M_1024 pascal_voc \
#   --set EXP_DIR foobar RNG_SEED 42 TRAIN.SCALES "[400, 500, 600, 700]"

set -x
set -e

export PYTHONUNBUFFERED="True"

DEV=$1
DEV_ID=$2
NET=$3
DATASET=$4

array=( $@ )
len=${#array[@]}
EXTRA_ARGS=${array[@]:4:$len}
EXTRA_ARGS_SLUG=${EXTRA_ARGS// /_}

case $DATASET in
  pascal_voc)
    TRAIN_IMDB="voc_2007_trainval"
    TEST_IMDB="voc_2007_test"
    PT_DIR="pascal_voc"
    ITERS=70000
    ;;
  # Add the following caltech case
  caltech)
    TRAIN_IMDB="caltech_trainval"
    TEST_IMDB="caltech_test"
    PT_DIR="caltech"
    ITERS=500000
    ;;
  coco)
    # This is a very long and slow training schedule
    # You can probably use fewer iterations and reduce the
    # time to the LR drop (set in the solver to 350,000 iterations).
    TRAIN_IMDB="coco_2014_train"
    TEST_IMDB="coco_2014_minival"
    PT_DIR="coco"
    ITERS=490000
    ;;
  *)
    echo "No dataset given"
    exit
    ;;
esac

LOG="experiments/logs/faster_rcnn_end2end_${NET}_${EXTRA_ARGS_SLUG}.txt.`date +'%Y-%m-%d_%H-%M-%S'`"
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

time python ./tools/train_net.py --device ${DEV} --device_id ${DEV_ID} \
  --weights data/pretrain_model/VGG_imagenet.npy \
  --imdb ${TRAIN_IMDB} \
  --iters ${ITERS} \
  --cfg experiments/cfgs/faster_rcnn_end2end.yml \
  --network VGGnet_train \
  ${EXTRA_ARGS}

set +x
NET_FINAL=`grep -B 1 "done solving" ${LOG} | grep "Wrote snapshot" | awk '{print $4}'`
set -x

time python ./tools/test_net.py --device ${DEV} --device_id ${DEV_ID} \
  --weights ${NET_FINAL} \
  --imdb ${TEST_IMDB} \
  --cfg experiments/cfgs/faster_rcnn_end2end.yml \
  --network VGGnet_test \
  ${EXTRA_ARGS}

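With these changes, training and testing on Caltech are launched the same way as for pascal_voc, only with the new dataset name, e.g. ./experiments/scripts/faster_rcnn_end2end.sh gpu 0 VGG16 caltech; as the script above shows, the four positional arguments are the device type, device id, net name (used only in the log file name, since the network is hardcoded to VGGnet_train), and dataset.
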
Reposted from blog.csdn.net/qq_33297776/article/details/80039104