用自己的数据训练Faster-RCNN,tensorflow版本(二)

我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF
参考博客http://www.cnblogs.com/CarryPotMan/p/5390336.html

用自己的数据训练Faster-RCNN,tensorflow版本(一)中我们详细介绍了Faster-rcnn_TF中pascal_voc数据的读写接口,接下来介绍一下,如何编写自己的数据读写接口。

3、编写自己的数据读写接口

我们要用自己的数据进行训练,就得编写自己数据的读写接口,下面参考pascal_voc.py来编写。根据用自己的数据训练Faster-RCNN,tensorflow版本(一)中对pascal_voc.py文件的分析,发现,pascal_voc.py用了非常多的路径拼接,很麻烦,我们不用这么麻烦,简单一点就可以。

3.1、介绍一下我自己的训练数据集格式

我主要是从自然图片中检测出文本,因此我只有background 和text两类物体,我并没有像pascal_voc数据集里面一样每个图像用一个xml来标注,先说一下我的数据格式:

所有需要用到的数据我都放在了目录Data/ID_card/下面。
目录Data/ID_card/下面包含2个文件夹,分别是train,test。
先介绍train,目录Data/ID_card/train/里面包含:
1、所有的训练图片
2、gt_ID_card.txt
3、train.txt

我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面,gt_ID_card.txt格式如下:
gt_ID_card.txt

以第一行为例:
ID_card/back_1.jpg: 是图片的名字;
数字1:代表该张图片上只有一个文本(text);
后面的四个数值:分别是文本框左上角和右下角的坐标。我的图片里面只有一行文本,所以只有一组文本框的坐标。

train.txt文件存放的是所有图片的名字,没有后缀,如下图:
train.txt

3.2、编写自己的数据读写接口ID_card.py

主要修改的关键函数就是:def _load_annotation(self)——读取图片gt。

编写自己的数据读写接口ID_card.py,内容如下:

#coding:utf-8
# --------------------------------------------------------
# 
# Written by lisiqi
# --------------------------------------------------------

import datasets
import os
import datasets.imdb
import xml.dom.minidom as minidom
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess

class ID_card(datasets.imdb):
    def __init__(self, image_set, data_path=None):
        datasets.imdb.__init__(self, 'ID_card_' + image_set) #image_set 为train或者val或者trainval或者test。
        self._image_set = image_set # image_set以train为例
        self._data_path = data_path # 数据所在的路径,根据传进来的参数data_path而定。传进来的参数data_path在我这里就是Data/ID_card/
        self._classes = ('__background__','text') #object的类别,只有两类:背景和文本
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #构成字典{'__background__':'0','text':'1'}
        self._image_ext = '.jpg' #图片后缀
        self._image_index = self._load_image_set_index() #读取train.txt,获取图片名称(该图片名称没有后缀.jpg)
        # Default to roidb handler
        self._roidb_handler = self.gt_roidb #获取图片的gt
        # PASCAL specific config options
        self.config = {'cleanup'  : True,
                       'use_salt' : True,
                       'top_k'    : 2000}

        assert os.path.exists(self._data_path), \ #如果路径Data/ID_card不存在,退出
                'Image Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):#获得_image_index 下标为i的图像的路径
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):#根据_image_index获取图像路径
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, index, self._image_ext)
        assert os.path.exists(image_path), \
                'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):#已做修改
        """
        Load the indexes listed in this dataset's image set file.
        得到图片名称的list。这个list里面是集合self._image_set=train中所有图片的名字(注意,图片名字没有后缀.jpg)
        """
        image_set_file = os.path.join(self._data_path, self._image_set, self._image_set + '.txt') 
        #image_set_file是Data/ID_card/train/train.txt
        #之所以要读这个train.txt文件,是因为train.txt文件里面写的是集合train中所有图片的名字(没有后缀.jpg)
        assert os.path.exists(image_set_file), \
                'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f: #读取train.txt,获取图片名称(没有后缀.jpg)
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.
        读取并返回图片gt的db。这个函数就是将图片的gt加载进来。
        其中,图片的gt信息在gt_ID_card.txt文件中
        并且,图片的gt被提前放在了一个.pkl文件里面。(这个.pkl文件需要我们自己生成,代码就在该函数中)

        This function loads/saves from/to a cache file to speed up future calls.
        之所以会将图片的gt提前放在一个.pkl文件里面,是为了不用每次都再重新读图片的gt,直接加载这个文件就可以了,可以提升速度。
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):#若存在cache file则直接从cache file中读取
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        gt_roidb = self._load_annotation()  #读入整个gt文件的具体实现
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb

    #def selective_search_roidb(self):#在没有使用RPN的时候,是这样提取候选框,fast-rcnn会用到。我直接删除了这个函数,faster-rcnn用不到     
    #def _load_selective_search_roidb(self, gt_roidb):#用不到,删除
    #def selective_search_IJCV_roidb(self):  #用不到,删除      
    #def _load_selective_search_IJCV_roidb(self, gt_roidb): #用不到,删除      

    def _load_annotation(self):
        """
        Load image and bounding boxes info from txt format.
        读取图片的gt的具体实现。
        我把train集合中所有图片的gt,集中放在了一个gt_ID_card.txt文件里面
        gt_ID_card.txt中每行的格式如下:ID_card/train/back_1.jpg 1 147 65 443 361      
        后面的四个数值分别是文本框左上角和右下角的坐标。我的图片里面只有一个文本,所以只有一组文本框的坐标
        """
        gt_roidb = []      
        txtfile = os.path.join(self._data_path, 'gt_ID_card.txt')
        f = open(txtfile)
        split_line = f.readline().strip().split()
        num = 1
        while(split_line):
            num_objs = int(split_line[1])
            boxes = np.zeros((num_objs, 4), dtype=np.uint16)
            gt_classes = np.zeros((num_objs), dtype=np.int32)
            overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
            for i in range(num_objs):
                x1 = float( split_line[2 + i * 4])
                y1 = float (split_line[3 + i * 4])
                x2 = float (split_line[4 + i * 4])
                y2 = float (split_line[5 + i * 4])
                cls = self._class_to_ind['text']
                boxes[i,:] = [x1, y1, x2, y2]
                gt_classes[i] = cls
                overlaps[i,cls] = 1.0

            overlaps = scipy.sparse.csr_matrix(overlaps)
            gt_roidb.append({'boxes' : boxes, 'gt_classes': gt_classes, 'gt_overlaps' : overlaps, 'flipped' : False})
            split_line = f.readline().strip().split()

        f.close()
        return gt_roidb

    #def _write_voc_results_file(self, all_boxes):#没用,删掉        
    #def _do_matlab_eval(self, comp_id, output_dir='output'): #没用,删掉       
    #def evaluate_detections(self, all_boxes, output_dir):# 没用,删掉

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True

if __name__ == '__main__':
    import datasets.ID_card #作了修改
    d = datasets.ID_card('train', 'Data/ID_card/')#datasets.ID_card()在factory.py中用到了,
    res = d.roidb
    from IPython import embed; embed()

到这里,就完成了整个的读取接口的改写,主要在gt的读取。
除了要修改数据读写接口,还有一些文件需要修改。

3.3、修改factory.py

建议先将原来的factory.py复制成factory_bak.py作为备份,然后再在factory.py上进行修改。

修改后的factory.py如下:

"""Factory method for easily getting imdbs by name."""

import datasets.ID_card as ID_card #首先在文件头import把pascal_voc改成ID_card

__sets = {}
image_set = 'train' 
data_path = '/data/home/lisiqi/Data/ID_card' #自己数据的路径

def get_imdb(name): # 当网络训练时会调用factory里面的get_imdb方法获得相应的imdb
    """Get an imdb (image database) by name."""
    __sets[name] = (lambda image_set=image_set, data_path=data_path: ID_card.ID_card(image_set,data_path)) #ID_card.ID_card()的意思是调用文件ID_card.py中的类ID_card
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()

def list_imdbs():
    """List all registered imdbs."""
    return __sets.keys()

3.4、修改模型文件配置

3.4.1、修改config.py

工程Faster-RCNN_TF模型的参数都在文件Faster-RCNN_TF/lib/fast-rcnn/config.py中被定义。

将config.py中有如下参数的地方,按照下面的进行修改:

# Images to use per minibatch
__C.TRAIN.IMS_PER_BATCH = 1 #每次输入到faster-rcnn网络中的图片数量是1张

# Iterations between snapshots
__C.TRAIN.SNAPSHOT_ITERS = 1000  # 训练的时候,每1000步保存一次模型。这个可以自己随意设置

__C.TRAIN.SNAPSHOT_PREFIX = 'VGGnet_faster_rcnn' #模型在保存时的名字
# Use RPN to detect objects
__C.TRAIN.HAS_RPN = True #是否使用RPN。True代表使用RPN

3.4.2、修改VGG_train.py和VGG_test.py

要想启动Faster RCNN网络训练,需要用到文件Faster-RCNN_TF/lib/networks/VGGnet_train.py。
因为我的任务是检测自然图像中的文本,所以我的检测目标物是text,那么我的类别就有两个类别即 background 和 text。

VGGnet_train.py需要修改的地方如下:

把n_classes 从原来的21类(20类+背景) ,改成 2类(人+背景),其它不用变。
这里写图片描述

3.5、启动Faster RCNN网络训练

网络的训练文件是Faster_RCNN-TF/tools/train_net.py,内容如下:


"""Train a Fast R-CNN network on a region of interest database."""

import _init_paths
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg,cfg_from_file, cfg_from_list, get_output_dir
from datasets.factory import get_imdb
from networks.factory import get_network
import argparse
import pprint
import numpy as np
import sys
import os  #新增加的
import pdb #打断点时,会用到

def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--device', dest='device', help='device to use',
                        default='cpu', type=str)
    parser.add_argument('--device_id', dest='device_id', help='device id to use',
                        default=0, type=int)
    parser.add_argument('--solver', dest='solver',
                        help='solver prototxt',
                        default=None, type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--weights', dest='pretrained_model',
                        help='initialize with pretrained model weights',
                        default=None, type=str)
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='kitti_train', type=str)
    parser.add_argument('--rand', dest='randomize',
                        help='randomize (do not use a fixed seed)',
                        action='store_true')
    parser.add_argument('--network', dest='network_name',
                        help='name of the network',
                        default=None, type=str)
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)
    if not args.randomize:
        # fix the random seeds (numpy and caffe) for reproducibility
        np.random.seed(cfg.RNG_SEED)
    imdb = get_imdb(args.imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    roidb = get_training_roidb(imdb)

    output_dir = get_output_dir(imdb, None)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    # 设置cpu或者gpu的id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id)
    device_name = '/{}:{:d}'.format(args.device,args.device_id)
    print device_name

    network = get_network(args.network_name)
    print 'Use network `{:s}` in training'.format(args.network_name)
    #pdb.set_trace() #在此处设置一个断点
    train_net(network, imdb, roidb, output_dir,
              pretrained_model=args.pretrained_model,
              max_iters=args.max_iters)

在终端,启动网络训练。在路径Faster-RCNN_TF下,输入:

python ./tools/train_net.py --device gpu --device_id 3 --solver VGG_CNN_M_1024 --weight ./data/pretrain_model/VGG_imagenet.npy --imdb ID_card_train --network IDcard_train

参数解释:
train_net.py: 是网络的训练文件
—device :代表选用cpu还是gpu
—device_id: 代表机器上的cpu或者gpu的编号
—solver: 模型的配置文件,这个参数就不要进行修改了,固定就是VGG_CNN_M_1024
—weight: 初始化的权重文件,这里用的是Imagenet上预训练好的模型VGG_imagenet.npy,存放在目录./data/pretrain_model下
—imdb: 训练的数据库名字,这个名字可以自己随便起
—network: 代表选择训练网络还是测试网络,这个参数的形式是固定的,必须是IDcard_train的形式,前半部分IDcard可以随便起(但是不能有下划线),后半部分必须是_train

训练完成之后的模型默认保存在了目录./output/default/ID_card_train下。我们会发现,该目录下会出现以下文件:

这里写图片描述

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。模型会保存在后缀为.ckpt的文件中。保存后,在目录./output/default/ID_card_train下会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。参考自TensorFlow学习笔记(8)–网络模型的保存和读取

checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.

model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。

model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。

3.6、测试Faster RCNN网络训练的模型

参考./tools/demo.py,写自己的demo.py。

由于我所使用的服务器中无法使用plot,所以我将检测的坐标结果直接画在了测试图片上,并且将图片保存在了目录./results下。

修改后的demo.py内容如下:

import _init_paths
import tensorflow as tf
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, sys, cv2
import argparse
from networks.factory import get_network
import glob

import os
#os.environ['CUDA_VISIBLE_DEVICES']='3'

import pdb #设断点时,使用的


CLASSES = ('__background__', 'text')物体类别


def vis_detections(im, class_name, dets, image_name, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    im = im.copy()

    for i in inds:
        bbox = dets[i, :4] #检测图片的坐标
        score = dets[i, -1]

        cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 0, 255), 3) #将坐标直接画在了检测图片上

    cv2.imwrite('./results/' + image_name, im) #将带有框的检测图片保存在目录./results下
    print class_name

def demo(sess, net, image_name_path):
    """Detect object classes in an image using pre-computed object proposals."""

    # load images
    im = cv2.imread(image_name_path)
    im_name = os.path.basename(image_name_path)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    print boxes
    timer.toc()
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    # ax.imshow(im, aspect='equal')

    CONF_THRESH = 0.5
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        #vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)
        vis_detections(im, cls, dets, im_name, thresh=CONF_THRESH)
def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Faster R-CNN demo')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
                        default='VGGnet_test')
    parser.add_argument('--model', dest='model', help='Model path',
                        default='/data/home/lisiqi/Faster-RCNN_TF_original/weight/VGGnet_fast_rcnn_iter_70000.ckpt')
    parser.add_argument('--results_dir', dest='results_dir', help='Results director',
                        default=' ')

    args = parser.parse_args()

    return args
if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    # GPU id(设置GPU的编号)
    os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu_id)

    if args.model == ' ':
        raise IOError(('Error: Model not found.\n'))

    if args.results_dir == ' ':
        raise IOError(('Error: Result director not found.\n'))   

    # init session
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    # load network
    net = get_network(args.demo_net)
    # load model
    #saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
    #pdb.set_trace()
    saver = tf.train.Saver()
    saver.restore(sess, args.model) #加载训练好的模型,名称就写到.ckpt就行,例如VGGnet_faster_rcnn_iter_1000.ckpt

    #sess.run(tf.initialize_all_variables())

    print '\n\nLoaded network {:s}'.format(args.model)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _= im_detect(sess, net, im)

    # load images
    im_file_dir = cfg.DATA_DIR + '/demo/'
    im_names_path = glob.glob(im_file_dir + '*.jpg')
    #pdb.set_trace()

    for im_name_path in im_names_path:
        #im_name = os.path.basename(im_name_path)
        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
        print 'Demo for data/demo/{}'.format(im_name_path)
        #pdb.set_trace()
        demo(sess, net, im_name_path)

    print 'results_dir:{}'.format(args.results_dir)
    #plt.show()

在终端,路径Faster-RCNN_TF下,输入:

 python ./tools/demo.py --gpu 3 --model ./output/default/ID_card_train/VGGnet_faster_rcnn_iter_1000.ckpt --results ./results/

参数解释:
demo.py: 测试图片的文件
—gpu : 代表机器上gpu的编号(直接就默认使用gpu,没有cpu选项)
—model: 网络训练好的模型
—results: 保存结果的路径

猜你喜欢

转载自blog.csdn.net/lanyuelvyun/article/details/78094003