ImageNet 数据标注文件的读取（read）和写入（write）



class PicAnno:
    """Annotation record for one picture.

    Holds the picture's folder, file name, size (width, height, depth) and
    the list of annotated objects.  Every field exists from construction on;
    a field that has not been set yet is ``None``.
    """

    def __init__(self, folder):
        """Create an annotation without objects for a picture in *folder*."""
        self.folder = folder
        # Declare every field up front so that reading one before its setter
        # has been called yields None instead of raising AttributeError.
        self.filename = None
        self.width = None
        self.height = None
        self.depth = None
        # One list per instance; a class-level list would be shared by all
        # annotations.
        self.objects = []

    def set_folder(self, folder):
        """Replace the folder value."""
        self.folder = folder

    def set_filename(self, filename):
        """Set the picture's file name."""
        self.filename = filename

    def set_size(self, width, height, depth):
        """Set the three size fields; the values are stored unchanged."""
        self.width = width
        self.height = height
        self.depth = depth

    def add_object(self, object):
        """Append one annotated object to ``self.objects``.

        The parameter name shadows the builtin ``object``; it is kept so that
        existing keyword callers continue to work.
        """
        self.objects.append(object)


class PicObject:
    """One annotated object: a name, three flags/labels and a bounding box.

    Every field exists from construction on; a field that has not been set
    yet is ``None``.
    """

    def __init__(self, name):
        """Create an object called *name* with all other fields unset."""
        self.name = name
        # Declare the remaining fields so that attribute access never raises
        # AttributeError, whichever setters have been called so far.
        self.pose = None
        self.truncated = None
        self.difficult = None
        self.xmin = None
        self.ymin = None
        self.xmax = None
        self.ymax = None

    def set_name(self, name):
        """Replace the object's name."""
        self.name = name

    def set_pose(self, pose):
        """Set the pose value."""
        self.pose = pose

    def set_truncated(self, truncated):
        """Set the truncated value."""
        self.truncated = truncated

    def set_difficult(self, difficult):
        """Set the difficult value."""
        self.difficult = difficult

    def set_bndbox(self, xmin, ymin, xmax, ymax):
        """Set the four bounding-box coordinates; values are stored unchanged."""
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax


class VocUtil:
    """Read, build and save annotation XML with an ``<annotation>`` root.

    The methods use ``etree``, ``Element``, ``SubElement`` and ``tostring``
    from module scope.  The ``xpath`` calls and the ``pretty_print`` argument
    suggest these come from lxml -- TODO confirm against the file's imports.
    """

    @staticmethod
    def _first_text(node, path):
        """Return the text of the first element matching *path*, else None."""
        matches = node.xpath(path)
        return matches[0].text if matches else None

    @staticmethod
    def _add_text_node(parent, tag, value, coerce=True):
        """Append a ``<tag>`` child to *parent* and return it.

        The child's text is set only when *value* is not None.  With
        *coerce* the value is passed through ``str()`` first; without it the
        value is assigned as given.
        """
        node = SubElement(parent, tag)
        if value is not None:
            node.text = str(value) if coerce else value
        return node

    def read_anno_xml(self, xml_path):
        """Load the annotation file at *xml_path* into a ``PicAnno``.

        Every value is kept as the raw element text; no numeric conversion
        is done.  An element that is absent from the file is read as None
        rather than aborting the whole read with an IndexError.

        Returns:
            A ``PicAnno`` with one ``PicObject`` per ``<object>`` element.
        """
        text_of = VocUtil._first_text
        root = etree.parse(xml_path).getroot()

        anno = PicAnno(text_of(root, '/annotation/folder'))
        anno.set_filename(text_of(root, '/annotation/filename'))
        anno.set_size(text_of(root, '/annotation/size/width'),
                      text_of(root, '/annotation/size/height'),
                      text_of(root, '/annotation/size/depth'))

        for obj_node in root.xpath('/annotation/object'):
            # Paths below are relative to the current <object> element.
            pic_object = PicObject(text_of(obj_node, 'name'))
            pic_object.set_pose(text_of(obj_node, 'pose'))
            pic_object.set_truncated(text_of(obj_node, 'truncated'))
            pic_object.set_difficult(text_of(obj_node, 'difficult'))
            pic_object.set_bndbox(text_of(obj_node, 'bndbox/xmin'),
                                  text_of(obj_node, 'bndbox/ymin'),
                                  text_of(obj_node, 'bndbox/xmax'),
                                  text_of(obj_node, 'bndbox/ymax'))
            anno.add_object(pic_object)
        return anno

    def parse_anno_xml(self, picAnno):
        """Serialize *picAnno* to pretty-printed XML text.

        Despite the name, this method builds XML from an annotation object;
        it does not parse anything.  Every element is always emitted; an
        attribute that is missing or None produces an element without text.

        Returns:
            The XML document as a ``str``.
        """
        add = VocUtil._add_text_node
        root = Element('annotation')

        # folder, filename and the object name are assigned without str(),
        # exactly as given by the caller.
        add(root, 'folder', getattr(picAnno, 'folder', None), coerce=False)
        add(root, 'filename', getattr(picAnno, 'filename', None), coerce=False)

        size_node = SubElement(root, 'size')
        # getattr with a default covers all three size fields, including
        # depth, so an annotation whose size was never set still serializes.
        for tag in ('width', 'height', 'depth'):
            add(size_node, tag, getattr(picAnno, tag, None))

        for obj in picAnno.objects:
            object_node = SubElement(root, 'object')
            add(object_node, 'name', getattr(obj, 'name', None), coerce=False)
            for tag in ('pose', 'truncated', 'difficult'):
                add(object_node, tag, getattr(obj, tag, None))
            box_node = SubElement(object_node, 'bndbox')
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                add(box_node, tag, getattr(obj, tag, None))

        return tostring(root, pretty_print=True).decode('utf-8')

    def save_anno_xml(self, xml_path, xml_text):
        """Write *xml_text* to the file at *xml_path*, replacing its content.

        The file is written as UTF-8 instead of the platform default
        encoding, so the result does not depend on the machine's locale.
        """
        with open(xml_path, 'w', encoding='utf-8') as f:
            f.write(xml_text)

猜你喜欢

转载自blog.csdn.net/otengyue/article/details/79243559
今日推荐