mmdetection(一)安装及训练、测试VOC格式的数据

一、安装

         https://github.com/open-mmlab/mmdetection/blob/master/docs/INSTALL.md

二、训练自己的数据

 1、数据 

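     mmdet 默认使用 COCO 格式的数据,这里以 VOC 格式为例。data 目录下的文件夹结构如下(原文配图缺失,以下按后文配置中的路径整理):

         data/VOCdevkit/VOC2007/Annotations/        (xml 标注文件)
         data/VOCdevkit/VOC2007/JPEGImages/         (图片)
         data/VOCdevkit/VOC2007/ImageSets/Main/     (trainval.txt、test.txt)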

   2、训练

         (1)修改configs文件下的文件

           可先复制一份配置文件,然后自己命名一下,比如 retinanet_x101_64x4d_fpn_1x.py。需要修改的地方主要有两处:一是 dataset settings 部分,可直接参考 pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py(如下);二是该配置文件中的 num_classes(设为 类别数+1)。

# dataset settings
# VOC-style dataset rooted at data/VOCdevkit/; only the VOC2007 folder is
# referenced by the paths below.
dataset_type = 'VOCDataset'
data_root = 'data/VOCdevkit/'
# Per-channel mean/std handed to the Normalize step of both pipelines.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# Training-time transforms, listed in the order they are applied.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
# Test-time transforms: a single scale with flip=False, so no actual
# multi-scale/flip augmentation is configured here.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1000, 600),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
# Dataset wiring: train wraps the VOC2007 trainval list in RepeatDataset
# (times=3); val and test both read the same VOC2007 test.txt list.
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type=dataset_type,
            ann_file=[
                data_root + 'VOC2007/ImageSets/Main/trainval.txt'
                
            ],
            img_prefix=[data_root + 'VOC2007/'],
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
        img_prefix=data_root + 'VOC2007/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='mAP')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
# The train set is repeated 3 times per epoch (RepeatDataset times=3), so a
# step at epoch 3 corresponds to 3 * 3 = 9 passes over trainval.
lr_config = dict(policy='step', step=[3])  # actual epoch = 3 * 3 = 9
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
# NOTE(review): the comment here used to read "4 * 3 = 12", which does not
# match the value 100 — confirm that 100 epochs is intended given that
# lr_config steps only once, at epoch 3.
total_epochs = 100  # with RepeatDataset times=3: 100 * 3 = 300 passes
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_voc0712'
load_from = None
resume_from = None
workflow = [('train', 1)]
View Code

          (2)修改 mmdet/datasets/voc.py 中 VOCDataset 的 CLASSES,改为自己的类别(只有一个类别时,元组末尾要保留逗号,如 ('cat',))

          (3)训练

      python tools/train.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py 

 3、测试

      (1)输出mAP

     修改 mmdetection/mmdet/core/evaluation/class_names.py 中的 voc_classes(),使其返回自己的类别

        python3 tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py work_dirs/faster_rcnn_r50_fpn_1x_voc0712/latest.pth --eval mAP    --show

    (2)  测试单张图片

        参考 demo/webcam_demo.py,新建 demo/img_demo.py(代码见下),然后运行:

python demo/img_demo.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py work_dirs/faster_rcnn_r50_fpn_1x_voc0712/latest.pth demo/2017-09-05-161908.jpg

        

import argparse
import torch

from mmdet.apis import inference_detector, init_detector, show_result


def parse_args():
    """Parse the command line of the image demo.

    Positional arguments are ``config``, ``checkpoint`` and ``imagepath``.
    Optional arguments are ``--device`` (int, default 0) and ``--score-thr``
    (float, default 0.5, exposed as ``score_thr``).

    Returns:
        argparse.Namespace: The parsed arguments read from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='MMDetection image demo')

    # The three required positionals differ only in name and help text.
    positionals = (
        ('config', 'test config file path'),
        ('checkpoint', 'checkpoint file'),
        ('imagepath', 'the path of image to test'),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)

    cli.add_argument('--device', type=int, default=0, help='CUDA device id')
    cli.add_argument(
        '--score-thr', type=float, default=0.5, help='bbox score threshold')
    return cli.parse_args()


def main():
    """Run the detector on one image and display the result.

    Builds the model from the config and checkpoint given on the command
    line, places it on the CUDA device selected by ``--device``, runs
    inference on ``imagepath`` and hands the result to ``show_result``
    with the ``--score-thr`` threshold.
    """
    opts = parse_args()

    # The device index from the CLI is always interpreted as a CUDA device.
    device = torch.device('cuda', opts.device)
    detector = init_detector(opts.config, opts.checkpoint, device=device)

    detections = inference_detector(detector, opts.imagepath)

    # wait_time=0 is forwarded unchanged; presumably it keeps the window
    # open until a key is pressed — confirm against show_result.
    show_result(
        opts.imagepath,
        detections,
        detector.CLASSES,
        score_thr=opts.score_thr,
        wait_time=0)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

参考

https://zhuanlan.zhihu.com/p/101202864

https://blog.csdn.net/laizi_laizi/article/details/104256781

猜你喜欢

转载自www.cnblogs.com/573177885qq/p/12734630.html
今日推荐