mmseg底层代码分析及修改

本人采用voc格式数据集对mmseg底层代码进行分析,本文适合对mmseg使用流程比较熟悉的同学食用

数据增广部分代码分析

1.数据增广使用:

mmsegmentation/tools/convert_datasets/voc_aug.py中修改AUG_LEN参数为对应训练集图片数量,并在config文件声明如下:

train=dict(
ann_dir=[‘SegmentationClass’, ‘SegmentationClassAug’],
split=[
‘ImageSets/Segmentation/train.txt’,
‘ImageSets/Segmentation/aug.txt’
]))

2.代码解析

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat

AUG_LEN = 10582
#对应训练集图片数量

"""将.mat文件转换为.png文件"""
def convert_mat(mat_file, in_dir, out_dir):
    data = loadmat(osp.join(in_dir, mat_file))
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
    Image.fromarray(mask).save(seg_filename, 'PNG')

#生成
def generate_aug_list(merged_list, excluded_list):
    return list(set(merged_list) - set(excluded_list))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('aug_path', help='pascal voc aug path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    """参数初始化"""
    args = parse_args()
    devkit_path = args.devkit_path
    #VOC2012数据地址
    aug_path = args.aug_path
    #数据增广地址
    nproc = args.nproc
    #并行数
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    #数据增广输出地址
    mmcv.mkdir_or_exist(out_dir)
    #创建数据增广输出地址
    in_dir = osp.join(aug_path, 'dataset', 'cls')
    #拼接数据增广地址

    mmcv.track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(mmcv.scandir(in_dir, suffix='.mat')),
        nproc=nproc)
    # 并行任务的跟踪进度

    full_aug_list = []
    with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
        full_aug_list += [line.strip() for line in f]
    with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
        full_aug_list += [line.strip() for line in f]
    #从增广处获取训练验证标签名称列表
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'train.txt')) as f:
        ori_train_list = [line.strip() for line in f]
    #获取训练集标签名称列表
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'val.txt')) as f:
        val_list = [line.strip() for line in f]
    # 获取验证集标签名称列表
    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
        AUG_LEN)
    #生成增广的训练标签列表:set(原训练标签列表+增广数据集列表)-set(验证集标签列表)

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'trainaug.txt'), 'w') as f:
        f.writelines(line + '\n' for line in aug_train_list)
    #增广训练数据标签列表写入trainaug.txt
    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    assert len(aug_list) == AUG_LEN - len(
        ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
                                                     len(ori_train_list))
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
            'w') as f:
        f.writelines(line + '\n' for line in aug_list)
    #获取并写入增广数据标签列表
    print('Done!')


if __name__ == '__main__':
    main()

本文参考内容:mmseg地址

猜你喜欢

转载自blog.csdn.net/qq_44840741/article/details/125840880
今日推荐