tf-faster-rcnn/tools/train_net.py
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Zheqi He, Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import _init_paths
from model.train_val import get_training_roidb, train_net
from model.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_output_tb_dir
from datasets.factory import get_imdb
import datasets.imdb
import argparse
import pprint
import numpy as np
import sys
import tensorflow as tf
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from nets.mobilenet_v1 import mobilenetv1
def parse_args():
    """Parse the command-line arguments of the training script.

    When the script is started without any argument, the help text is
    printed and the process exits with status 1.

    Returns:
        argparse.Namespace with the attributes ``cfg_file``, ``weight``,
        ``imdb_name``, ``imdbval_name``, ``max_iters``, ``tag``, ``net`` and
        ``set_cfgs``.
    """
    parser = argparse.ArgumentParser(description='Train a Faster R-CNN network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default=None, type=str)
    parser.add_argument('--weight', dest='weight',
                        help='initialize with pretrained model weights',
                        type=str)
    parser.add_argument('--imdb', dest='imdb_name',
                        help='dataset to train on',
                        default='voc_2007_trainval', type=str)
    parser.add_argument('--imdbval', dest='imdbval_name',
                        help='dataset to validate on',
                        default='voc_2007_test', type=str)
    parser.add_argument('--iters', dest='max_iters',
                        help='number of iterations to train',
                        default=70000, type=int)
    parser.add_argument('--tag', dest='tag',
                        help='tag of the model',
                        default=None, type=str)
    # Restrict --net to the backbones listed in the help text so that a typo
    # is rejected here, before any dataset is loaded.
    parser.add_argument('--net', dest='net',
                        help='vgg16, res50, res101, res152, mobile',
                        choices=('vgg16', 'res50', 'res101', 'res152', 'mobile'),
                        default='res50', type=str)
    # Everything after --set is collected verbatim as config overrides.
    parser.add_argument('--set', dest='set_cfgs',
                        help='set config keys', default=None,
                        nargs=argparse.REMAINDER)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    return args
def combined_roidb(imdb_names):
    """Load and merge the roidbs of one or more datasets.

    Args:
        imdb_names: dataset name, or several names joined with '+'.

    Returns:
        A tuple ``(imdb, roidb)``. With a single name, ``imdb`` comes from
        ``get_imdb(imdb_names)``. With several names, ``imdb`` is a
        ``datasets.imdb.imdb`` named after the full '+'-joined string and
        carrying the classes of the second dataset; ``roidb`` is the first
        dataset's roidb extended in place with all the others.
    """
    def get_roidb(imdb_name):
        # Load one dataset, select its proposal method and build its roidb.
        dataset = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(dataset)

    names = imdb_names.split('+')
    per_dataset = [get_roidb(name) for name in names]
    merged = per_dataset[0]

    # Single dataset: nothing to merge.
    if len(per_dataset) == 1:
        return get_imdb(imdb_names), merged

    for extra in per_dataset[1:]:
        merged.extend(extra)
    # The combined imdb takes its class list from the second dataset.
    reference = get_imdb(names[1])
    combined = datasets.imdb.imdb(imdb_names, reference.classes)
    return combined, merged
if __name__ == '__main__':
    args = parse_args()

    print('Called with args:')
    print(args)

    # Apply configuration overrides: the config file first, then --set pairs.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    print('Using config:')
    pprint.pprint(cfg)

    np.random.seed(cfg.RNG_SEED)

    # Training set.
    imdb, roidb = combined_roidb(args.imdb_name)
    print('{:d} roidb entries'.format(len(roidb)))

    # Output directory where the models are saved (get_output_dir is
    # imported from model.config).
    output_dir = get_output_dir(imdb, args.tag)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # TensorBoard directory where the summaries are saved during training;
    # args.tag is the tag of the model.
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # Also load the validation set, but with image flipping switched off
    # while it is being built; the previous setting is restored afterwards.
    saved_use_flipped = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    _, valroidb = combined_roidb(args.imdbval_name)
    print('{:d} validation roidb entries'.format(len(valroidb)))
    cfg.TRAIN.USE_FLIPPED = saved_use_flipped

    # Build the backbone network selected with --net.
    resnet_depths = {'res50': 50, 'res101': 101, 'res152': 152}
    if args.net in resnet_depths:
        net = resnetv1(num_layers=resnet_depths[args.net])
    elif args.net == 'vgg16':
        net = vgg16()
    elif args.net == 'mobile':
        net = mobilenetv1()
    else:
        raise NotImplementedError

    # Train the network.
    train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=args.weight,
              max_iters=args.max_iters)
factory.py中get_imdb()函数:get_roidb函数调用factory中的get_imdb,以name为key在__sets中取出对应的值(一个lambda表达式),调用该lambda后转到pascal_voc类
def get_imdb(name):
    """Return the imdb (image database) registered under ``name``.

    Args:
        name: key to look up in the module-level ``__sets`` registry.

    Returns:
        The result of calling the registry entry stored for ``name``.

    Raises:
        KeyError: if ``name`` is not present in ``__sets``.
    """
    if name in __sets:
        # Registry entries are zero-argument callables; call it to get the imdb.
        return __sets[name]()
    raise KeyError('Unknown dataset: {}'.format(name))
config.py中get_output_dir()函数
def get_output_dir(imdb, weights_filename):
    """Return the directory where experimental artifacts are placed.

    The path is the absolute form of
    ``<__C.ROOT_DIR>/output/<__C.EXP_DIR>/<imdb.name>`` with
    ``weights_filename`` appended, for example
    ``/home/ubuntu/tf-faster-rcnn-master/output/res101/voc_2007_trainval/default``.
    If the directory does not exist, it is created.

    Args:
        imdb: object whose ``name`` attribute names the dataset directory.
        weights_filename: last path component; ``None`` selects ``'default'``.

    Returns:
        The directory path as a string.

    Raises:
        OSError: if the directory is missing and cannot be created.
    """
    outdir = osp.abspath(osp.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name))
    if weights_filename is None:
        weights_filename = 'default'
    outdir = osp.join(outdir, weights_filename)
    # Create first and tolerate "already exists": a separate exists() check
    # followed by makedirs() can fail when another process creates the same
    # directory in between.
    try:
        os.makedirs(outdir)
    except OSError:
        if not os.path.exists(outdir):
            raise
    return outdir
config.py中get_output_tb_dir()函数
def get_output_tb_dir(imdb, weights_filename):
    """Return the directory where TensorFlow summaries are placed.

    The path is the absolute form of
    ``<__C.ROOT_DIR>/tensorboard/<__C.EXP_DIR>/<imdb.name>`` with
    ``weights_filename`` appended, for example
    ``/home/ubuntu/tf-faster-rcnn-master/tensorboard/res101/voc_2007_trainval/default``.
    If the directory does not exist, it is created.

    Args:
        imdb: object whose ``name`` attribute names the dataset directory.
        weights_filename: last path component; ``None`` selects ``'default'``.

    Returns:
        The directory path as a string.

    Raises:
        OSError: if the directory is missing and cannot be created.
    """
    outdir = osp.abspath(osp.join(__C.ROOT_DIR, 'tensorboard', __C.EXP_DIR, imdb.name))
    if weights_filename is None:
        weights_filename = 'default'
    outdir = osp.join(outdir, weights_filename)
    # Create first and tolerate "already exists": a separate exists() check
    # followed by makedirs() can fail when another process creates the same
    # directory in between.
    try:
        os.makedirs(outdir)
    except OSError:
        if not os.path.exists(outdir):
            raise
    return outdir