VOC格式或者COCO格式检测数据集提取特定类

序言

有时候我们需要从已经标记好的数据集中提取某些类进行训练,以常见的COCO数据集和VOC数据集格式的标注为例,本文提供了两种数据集格式的特定类提取方法,网上也有很多类似的内容,权当总结记录,以后用到时方便找出。

一、COCO格式数据集提取特定类

# COCO数据集提取某个类或者某些类

from pycocotools.coco import COCO
import os
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw

# 需要设置的路径
savepath = "D:\BaiduNetdiskDownload\coco\COCO/car/"
img_dir = savepath + 'images/'
anno_dir = savepath + 'annotations/'
datasets_list = ['train2017']

# coco有80类,这里写要提取类的名字,以car为例
classes_names = ['car','bus','truck']
# 包含所有类别的原coco数据集路径
'''
目录格式如下:
$COCO_PATH
----|annotations
----|train2017
----|val2017
----|test2017
'''
dataDir = 'D:\BaiduNetdiskDownload\coco\COCO/'

headstr = """\
<annotation>
    <folder>VOC</folder>
    <filename>%s</filename>
    <source>
        <database>My Database</database>
        <annotation>COCO</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <name>company</name>
    </owner>
    <size>
        <width>%d</width>
        <height>%d</height>
        <depth>%d</depth>
    </size>
    <segmented>0</segmented>
"""
objstr = """\
    <object>
        <name>%s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>%d</xmin>
            <ymin>%d</ymin>
            <xmax>%d</xmax>
            <ymax>%d</ymax>
        </bndbox>
    </object>
"""

tailstr = '''\
</annotation>
'''


# 检查目录是否存在,如果存在,先删除再创建,否则,直接创建
def mkr(path):
    if not os.path.exists(path):
        os.makedirs(path)  # 可以创建多级目录


def id2name(coco):
    classes = dict()
    for cls in coco.dataset['categories']:
        classes[cls['id']] = cls['name']
    return classes


def write_xml(anno_path, head, objs, tail):
    f = open(anno_path, "w")
    f.write(head)
    for obj in objs:
        f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
    f.write(tail)


def save_annotations_and_imgs(coco, dataset, filename, objs):
    # 将图片转为xml,例:COCO_train2017_000000196610.jpg-->COCO_train2017_000000196610.xml
    dst_anno_dir = os.path.join(anno_dir, dataset)
    mkr(dst_anno_dir)
    anno_path = dst_anno_dir + '/' + filename[:-3] + 'xml'
    img_path = dataDir + dataset + '/' + filename
    print("img_path: ", img_path)
    dst_img_dir = os.path.join(img_dir, dataset)
    mkr(dst_img_dir)
    dst_imgpath = dst_img_dir + '/' + filename
    print("dst_imgpath: ", dst_imgpath)
    img = cv2.imread(img_path)
    # if (img.shape[2] == 1):
    #    print(filename + " not a RGB image")
    #   return
    shutil.copy(img_path, dst_imgpath)

    head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
    tail = tailstr
    write_xml(anno_path, head, objs, tail)


def showimg(coco, dataset, img, classes, cls_id, show=True):
    global dataDir
    I = Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name']))
    # 通过id,得到注释的信息
    annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
    # print(annIds)
    anns = coco.loadAnns(annIds)
    # print(anns)
    # coco.showAnns(anns)
    objs = []
    for ann in anns:
        class_name = classes[ann['category_id']]
        if class_name in classes_names:
            # print(class_name)
            if 'bbox' in ann:
                bbox = ann['bbox']
                xmin = int(bbox[0])
                ymin = int(bbox[1])
                xmax = int(bbox[2] + bbox[0])
                ymax = int(bbox[3] + bbox[1])
                obj = [class_name, xmin, ymin, xmax, ymax]
                objs.append(obj)
                draw = ImageDraw.Draw(I)
                draw.rectangle([xmin, ymin, xmax, ymax])
    if show:
        plt.figure()
        plt.axis('off')
        plt.imshow(I)
        plt.show()

    return objs


for dataset in datasets_list:
    # ./COCO/annotations/instances_train2017.json
    annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)

    # 使用COCO API用来初始化注释数据
    coco = COCO(annFile)

    # 获取COCO数据集中的所有类别
    classes = id2name(coco)
    # print(classes)
    # [1, 2, 3, 4, 6, 8]
    classes_ids = coco.getCatIds(catNms=classes_names)
    # print(classes_ids)
    for cls in classes_names:
        # 获取该类的id
        cls_id = coco.getCatIds(catNms=[cls])
        img_ids = coco.getImgIds(catIds=cls_id)
        # print(cls, len(img_ids))
        # imgIds=img_ids[0:10]
        for imgId in tqdm(img_ids):
            img = coco.loadImgs(imgId)[0]
            filename = img['file_name']
            # print(filename)
            objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
            # print(objs)
            save_annotations_and_imgs(coco, dataset, filename, objs)

二、VOC格式数据集提取特定类

# VOC数据集提取某个类或者某些类
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import xml.etree.ElementTree as ET
import shutil

# 根据自己的情况修改相应的路径
ann_filepath = r'Annotations/'
img_filepath = r'JPEGImages/'
img_savepath = r'imgs/'
ann_savepath = r'xmls/'
if not os.path.exists(img_savepath):
    os.mkdir(img_savepath)

if not os.path.exists(ann_savepath):
    os.mkdir(ann_savepath)

# 这是VOC数据集中所有类别
# classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
#             'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
#              'dog', 'horse', 'motorbike', 'pottedplant',
#           'sheep', 'sofa', 'train', 'person','tvmonitor']

classes = ['car']  # 这里是需要提取的类别

def save_annotation(file):
    tree = ET.parse(ann_filepath + '/' + file)
    root = tree.getroot()
    result = root.findall("object")
    bool_num = 0
    for obj in result:
        if obj.find("name").text not in classes:
            root.remove(obj)
        else:
            bool_num = 1
    if bool_num:
        tree.write(ann_savepath + file)
        return True
    else:
        return False

def save_images(file):
    name_img = img_filepath + os.path.splitext(file)[0] + ".png"
    shutil.copy(name_img, img_savepath)
    # 文本文件名自己定义,主要用于生成相应的训练或测试的txt文件
    with open('train.txt', 'a') as file_txt:
        file_txt.write(os.path.splitext(file)[0])
        file_txt.write("\n")
    return True


if __name__ == '__main__':
    for f in os.listdir(ann_filepath):
        print(f)
        if save_annotation(f):
            save_images(f)

三、VOC格式数据集修改某个类的名字

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET

origin_ann_dir = r'xmls/'  # 设置原始标签路径为 Annos
new_ann_dir = r'xmls/'  # 设置新标签路径 Annotations
for dirpaths, dirnames, filenames in os.walk(origin_ann_dir):  # os.walk游走遍历目录名
    for filename in filenames:
        print("process...")
        if os.path.isfile(r'%s%s' % (origin_ann_dir, filename)):  # 获取原始xml文件绝对路径,isfile()检测是否为文件 isdir检测是否为目录
            origin_ann_path = os.path.join(r'%s%s' % (origin_ann_dir, filename))  # 如果是,获取绝对路径(重复代码)
            new_ann_path = os.path.join(r'%s%s' % (new_ann_dir, filename))
            tree = ET.parse(origin_ann_path)  # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
            root = tree.getroot()  # 获取根节点
            for object in root.findall('object'):  # 找到根节点下所有“object”节点
                name = str(object.find('name').text)  # 找到object节点下name子节点的值(字符串)
                # 功能1.删除指定类别的标签。如果name等于str,则删除该节点
                # if (name in ["car_head"]):
                #   root.remove(object)

                # 功能2.修改指定类别的标签。如果name等于str,则修改name
                if (name in ["car","bus","truck"]):           # 将car bus truck三个类改成car类
                    object.find('name').text = "car"

            # # 功能3.删除labelmap中没有的标签。检查是否存在labelmap中没有的类别
            # for object in root.findall('object'):
            #   name = str(object.find('name').text)
            #   if not (name in ["chepai","chedeng","chebiao","person"]):
            #       print(filename + "------------->label is error--->" + name)

            # # 功能4.比对xml中filename名称与图片名称是否一致。如果xml中filename名称与文件名称不一致,则对其进行修改
            # get_name = str(root.find('filename').text)
            # if filename.replace(".xml", ".jpg") != get_name:
            #       print("{}-->name is inconformity!".format(filename))
            #       root.find('filename').text = filename.replace(".xml", ".jpg")
            # else:
            #   continue

            tree.write(new_ann_path)  # tree为文件,write写入新的文件中。


猜你喜欢

转载自blog.csdn.net/qq_39056987/article/details/125082396