生成数据集代码

1.生成txt

import xml.etree.ElementTree as ET
from os import getcwd

# (year, split) pairs to process; one output list file is produced per pair.
sets = [('2007', split) for split in ('train', 'val', 'test')]

# Class names to keep. A name's position in this list is the integer class
# id written to the output, so the order must not be changed.
classes = [
    "yiyun",
    "sudashui",
    "milk4",
    "binghongcha",
    "ruishijuan",
    "liuliumei",
    "xiatiao",
    "suannai",
    "huangpi",
    "taozhi",
    "guozhi",
    "xiaowanzi",
    "milk5",
    "xuanmai",
    "xizhilang",
    "yanmai",
    "paomian2",
    "fangbianmian",
    "rusuan",
    "jinguan",
    "baiwei",
    "meinianda",
    "anmuxi",
    "milk2",
    "milk",
    "milk3",
    "hongniu",
    "zhenguoli",
    "baisuishan",
    "7up",
    "kekoukele",
    "binggan",
    "candou",
    "beer",
    "liugehetao",
    "yimian",
    "baishikele",
    "maidong",
    "quechao",
    "miaofu",
    "paomian",
    "guodong",
    "1664beer",
    "water",
    "wanglaoji",
    "aoliao",
    "beibingyang",
    "shupian",
]


def convert_annotation(year, image_id, list_file,
                       voc_root='/media/chenyu/2C4441FE4441CAF2',
                       class_names=None):
    """Append the bounding boxes of one VOC annotation to ``list_file``.

    Parses ``<voc_root>/VOC<year>/Annotations/<image_id>.xml`` and, for every
    ``<object>`` whose class is known and which is not flagged difficult,
    writes `` xmin,ymin,xmax,ymax,class_id`` (note the leading space) to
    ``list_file``. No newline is written; the caller terminates the line.

    :param year: dataset year used in the ``VOC<year>`` directory name.
    :param image_id: annotation file name without the ``.xml`` extension.
    :param list_file: writable text file object receiving the output.
    :param voc_root: directory that contains the ``VOC<year>`` folder.
    :param class_names: ordered class names; the index of a name is its
        class id. Defaults to the module-level ``classes`` list.
    :raises OSError: if the annotation file cannot be opened.
    :raises xml.etree.ElementTree.ParseError: if the file is not valid XML.
    """
    if class_names is None:
        class_names = classes

    annotation_path = '%s/VOC%s/Annotations/%s.xml' % (voc_root, year, image_id)
    # Parsing by path lets ElementTree open *and close* the file itself,
    # instead of leaving an open handle behind for every annotation.
    root = ET.parse(annotation_path).getroot()

    for obj in root.iter('object'):
        # A missing or empty <difficult> element is treated as "not difficult".
        difficult = obj.findtext('difficult') or '0'
        cls = obj.find('name').text
        if cls not in class_names or int(difficult) == 1:
            continue
        cls_id = class_names.index(cls)
        xmlbox = obj.find('bndbox')
        box = [int(xmlbox.find(tag).text)
               for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        list_file.write(" " + ",".join(str(v) for v in box) + ',' + str(cls_id))

#wd = getcwd()

# For every (year, split) pair write "<year>_<split>.txt" in the current
# directory: one line per image, holding the image path followed by the
# space-separated boxes appended by convert_annotation().
for year, image_set in sets:
    ids_path = '/media/chenyu/2C4441FE4441CAF2/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)
    # Context managers guarantee both files are closed, even when an
    # annotation fails to parse half-way through the split.
    with open(ids_path) as ids_file:
        image_ids = ids_file.read().strip().split()

    with open('%s_%s.txt' % (year, image_set), 'w') as list_file:
        for i, image_id in enumerate(image_ids, 1):
            list_file.write('/media/chenyu/2C4441FE4441CAF2/VOC%s/JPEGImages/%s.jpg' % (year, image_id))
            convert_annotation(year, image_id, list_file)
            list_file.write('\n')
            # Progress: number of images processed so far in this split.
            print(i)

2.生成csv

# -*- coding:utf-8 -*-

import csv
import os
import glob
import sys

class PascalVOC2CSV(object):
    """Convert Pascal VOC XML annotation files into two CSV files.

    Constructing an instance immediately parses every file in ``xml`` and
    writes the results: one ``image_path,x1,y1,x2,y2,class`` row per object
    to ``ann_path`` and one ``class,index`` row per distinct class (sorted by
    name) to ``classes_path``.
    """

    # Number of lines consumed after an ``<object>`` line. The first is used
    # as the class name and the last four as xmin, ymin, xmax, ymax.
    # NOTE(review): this presumably matches a name / pose / truncated /
    # difficult / <bndbox> / 4-coordinate layout with one tag per line;
    # confirm against the tool that produced the annotations.
    _OBJECT_FIELD_LINES = 9

    def __init__(self, xml=None, ann_path='./CSV/annotations.csv',
                 classes_path='./CSV/classes.csv', root_path=None):
        '''
        :param xml: list of paths to the Pascal VOC xml files to convert
        :param ann_path: output path of the annotations CSV
        :param classes_path: output path of the class-index CSV
        :param root_path: dataset root; image paths are written as
            ``<root_path>/JPEGImages/<name>.jpg``. When omitted, the
            module-level ``filepath`` variable is used.
        '''
        # None instead of a mutable default list shared between instances.
        self.xml = [] if xml is None else xml
        self.ann_path = ann_path
        self.classes_path = classes_path
        self.root_path = root_path
        self.label = []
        self.annotations = []
        self.filen_ame = None
        self.supercategory = None

        self.data_transfer()
        self.write_file()

    def data_transfer(self):
        """Parse every xml file, filling ``self.annotations`` and ``self.label``.

        Files that cannot be read or do not have the expected layout are
        skipped (best effort), but each skip is reported on stderr instead of
        being swallowed silently.
        """
        total = len(self.xml)
        for num, xml_file in enumerate(self.xml):
            # Progress output, rewritten in place on a single line.
            sys.stdout.write('\r>> Converting image %d/%d' % (num + 1, total))
            sys.stdout.flush()
            try:
                self._parse_file(xml_file)
            except (OSError, ValueError, IndexError) as err:
                sys.stderr.write('\nSkipping %s: %r\n' % (xml_file, err))

        sys.stdout.write('\n')
        sys.stdout.flush()

    def write_file(self,):
        """Write the annotations CSV and the sorted ``class,index`` CSV."""
        self._write_rows(self.ann_path, self.annotations)
        class_rows = [[name, num] for num, name in enumerate(sorted(self.label))]
        self._write_rows(self.classes_path, class_rows)

    def _parse_file(self, xml_file):
        """Scan one xml file line by line and record each object it contains."""
        # Reset per file so a file without a <filename> line cannot inherit
        # the image name of the previously parsed file.
        self.filen_ame = None
        with open(xml_file, 'r') as fp:
            for line in fp:
                if '<filename>' in line:
                    # Swap whatever extension is present for '.jpg'.
                    stem = os.path.splitext(self._tag_text(line))[0]
                    self.filen_ame = stem + '.jpg'

                if '<object>' in line:
                    if self.filen_ame is None:
                        raise ValueError('<object> found before <filename>')
                    try:
                        fields = [self._tag_text(next(fp))
                                  for _ in range(self._OBJECT_FIELD_LINES)]
                    except StopIteration:
                        raise ValueError('file ends inside an <object> element')

                    # Bounding box: the last four lines read.
                    x1, y1, x2, y2 = (int(value) for value in fields[-4:])

                    # Category: the first line read.
                    self.supercategory = fields[0]
                    if self.supercategory not in self.label:
                        self.label.append(self.supercategory)

                    image_path = os.path.join(
                        self._image_root(), 'JPEGImages', self.filen_ame)
                    self.annotations.append(
                        [image_path, x1, y1, x2, y2, self.supercategory])

    def _image_root(self):
        """Return the dataset root, falling back to the global ``filepath``."""
        if self.root_path is not None:
            return self.root_path
        return filepath

    @staticmethod
    def _tag_text(line):
        """Return the text between the first '>' and the following '<'."""
        return line.split('>')[1].split('<')[0]

    @staticmethod
    def _write_rows(path, rows):
        """Write ``rows`` to ``path`` as CSV, creating the directory if needed."""
        folder = os.path.dirname(path)
        if folder:
            os.makedirs(folder, exist_ok=True)
        # newline='' is required by the csv module; without it extra blank
        # lines appear on platforms that translate '\n'.
        with open(path, 'w', newline='') as fp:
            csv.writer(fp, dialect='excel').writerows(rows)


# Dataset root. PascalVOC2CSV reads this module-level name when it builds
# the image paths, so it has to exist before the converter is instantiated.
filepath = '/media/chenyu/2C4441FE4441CAF2/VOC2007'

# Collect every annotation file and run the conversion; instantiating the
# class performs the parsing and writes both CSV files.
annotation_pattern = '/media/chenyu/2C4441FE4441CAF2/VOC2007/Annotations/*.xml'
xml_file = glob.glob(annotation_pattern)
PascalVOC2CSV(xml=xml_file)

猜你喜欢

转载自blog.csdn.net/weixin_38740463/article/details/85694197
今日推荐