目标检测数据集图片及标签同步旋转角度

前言

在深度学习领域,尤其是目标检测任务中,数据集的质量直接影响模型的性能。为了提升模型的鲁棒性和对各种场景的适应能力,数据增强技术被广泛应用于图像数据集处理。旋转是常见的数据增强方法之一,通过对图像及其对应的标签(边界框)进行同步旋转,可以有效地增强模型对不同方向目标的识别能力。需要注意的是,在进行旋转变换时,除了图像本身需要旋转外,与之对应的标签(边界框)也必须同步更新。否则,标签位置的偏差将导致模型训练出现误差,甚至无法正确识别目标。

本篇文章将介绍如何对目标检测数据集中的图像和标签(边界框)进行同步旋转(旋转后的边界框仍为水平矩形框)。通过同步旋转图像和标签,能够确保数据增强过程中图像与标签的一致性,从而提高模型训练的效果和准确性。

效果如图,倾斜旋转。

代码

import os
import cv2
import math
import shutil
import numpy as np
from lxml import etree, objectify

class ImageXMLProcessor:
    """Rotate detection images and their XML bounding boxes in lockstep.

    Images are read from ``input_img_folder``; for each image an annotation
    file with the same stem and a ``.xml`` extension is looked up in
    ``input_xml_folder``.  The annotation layout is
    ``annotation/object/{name, bndbox/{xmin, ymin, xmax, ymax}}``.  The
    rotated image and a freshly generated XML file are written to the output
    folders under the original file names.
    """

    # Lower-case file extensions that ``process`` treats as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, input_img_folder, input_xml_folder, output_img_folder, output_xml_folder, rotation_angle=5, scale=1):
        """Store the configuration and create both output folders if missing.

        Args:
            input_img_folder: Folder containing the source images.
            input_xml_folder: Folder containing the source XML annotations.
            output_img_folder: Folder that receives the rotated images.
            output_xml_folder: Folder that receives the regenerated XML files.
            rotation_angle: Rotation angle in degrees, forwarded to
                ``cv2.getRotationMatrix2D``.
            scale: Isotropic scale factor applied together with the rotation.
        """
        self.input_img_folder = input_img_folder
        self.input_xml_folder = input_xml_folder
        self.output_img_folder = output_img_folder
        self.output_xml_folder = output_xml_folder
        self.rotation_angle = rotation_angle
        self.scale = scale
        os.makedirs(self.output_img_folder, exist_ok=True)
        os.makedirs(self.output_xml_folder, exist_ok=True)

    @staticmethod
    def _parse_coord(text):
        """Convert a coordinate string from the XML into an ``int``.

        Plain integers ("12") are returned unchanged; float-formatted values
        ("12.0", "3.6") are rounded to the nearest integer instead of raising
        ``ValueError`` the way a bare ``int(text)`` would.
        """
        return int(round(float(text)))

    def rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
        """Rotate an image and return axis-aligned boxes for the rotated objects.

        The output canvas is enlarged so that the whole rotated image fits.

        Args:
            img: Image array; ``img.shape[0]`` is the height and
                ``img.shape[1]`` the width.
            bboxes: Iterable of ``[xmin, ymin, xmax, ymax]`` boxes.
            angle: Rotation angle in degrees (same sign convention as
                ``cv2.getRotationMatrix2D``).
            scale: Isotropic scale factor.

        Returns:
            ``(rot_img, rot_bboxes)`` where ``rot_bboxes`` is a list of
            ``[xmin, ymin, xmax, ymax]`` integer boxes in the rotated image.
        """
        w = img.shape[1]
        h = img.shape[0]
        rangle = np.deg2rad(angle)

        # Size of the canvas that exactly contains the rotated, scaled image.
        nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
        nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale

        # Rotate about the centre of the new canvas ...
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
        # ... after shifting the source image so that its centre coincides
        # with the canvas centre.  The third component is 0, so only the
        # linear (rotation/scale) part of the matrix is applied to the shift.
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]

        canvas_size = (int(math.ceil(nw)), int(math.ceil(nh)))
        rot_img = cv2.warpAffine(img, rot_mat, canvas_size, flags=cv2.INTER_LANCZOS4)

        rot_bboxes = []
        for bbox in bboxes:
            xmin, ymin, xmax, ymax = bbox[0], bbox[1], bbox[2], bbox[3]
            x_mid = (xmin + xmax) / 2
            y_mid = (ymin + ymax) / 2
            # Transform the midpoints of the four edges (top, right, bottom,
            # left) rather than the corners, then take the upright rectangle
            # enclosing those four points.
            edge_midpoints = [
                np.dot(rot_mat, np.array([x_mid, ymin, 1])),
                np.dot(rot_mat, np.array([xmax, y_mid, 1])),
                np.dot(rot_mat, np.array([x_mid, ymax, 1])),
                np.dot(rot_mat, np.array([xmin, y_mid, 1])),
            ]
            # boundingRect needs integer points; astype truncates the floats.
            concat = np.vstack(edge_midpoints).astype(np.int32)
            rx, ry, rw, rh = cv2.boundingRect(concat)
            rot_bboxes.append([rx, ry, rx + rw, ry + rh])

        return rot_img, rot_bboxes

    def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
        """Build a new annotation XML and write it to ``save_folder/file_name``.

        Args:
            file_name: Name of the XML file to write.
            save_folder: Destination folder (created if missing).
            img_info: ``(folder_name, img_name)`` pair stored in the
                ``folder``, ``filename`` and ``path`` elements.
            height: Image height stored in ``size/height``.
            width: Image width stored in ``size/width``.
            channel: Channel count stored in ``size/depth``.
            bboxs_info: ``(labels, bboxes)`` pair; one ``object`` element is
                written per zipped ``(label, [xmin, ymin, xmax, ymax])``.
        """
        folder_name, img_name = img_info

        E = objectify.ElementMaker(annotate=False)

        # Header part of the annotation tree (everything except the objects).
        anno_tree = E.annotation(
            E.folder(folder_name),
            E.filename(img_name),
            E.path(os.path.join(folder_name, img_name)),
            E.source(
                E.database('Unknown'),
            ),
            E.size(
                E.width(width),
                E.height(height),
                E.depth(channel)
            ),
            E.segmented(0),
        )

        labels, bboxs = bboxs_info
        for label, box in zip(labels, bboxs):
            anno_tree.append(
                E.object(
                    E.name(label),
                    E.pose('Unspecified'),
                    E.truncated('0'),
                    E.difficult('0'),
                    E.bndbox(
                        E.xmin(box[0]),
                        E.ymin(box[1]),
                        E.xmax(box[2]),
                        E.ymax(box[3])
                    )
                ))

        os.makedirs(save_folder, exist_ok=True)
        file_path = os.path.join(save_folder, file_name)
        etree.ElementTree(anno_tree).write(file_path, pretty_print=True, xml_declaration=True, encoding="UTF-8")
        print(f"XML 文件已保存到: {file_path}")

    def update_bboxes_in_xml(self, xml_path, bboxes):
        """Overwrite the ``bndbox`` coordinates of an existing XML file in place.

        The i-th ``object`` element receives ``bboxes[i]``.  Objects without a
        ``bndbox`` child are left untouched (but still consume their index),
        and objects beyond ``len(bboxes)`` are not modified.

        Args:
            xml_path: Path of the XML file to update; it is rewritten.
            bboxes: Sequence of ``[xmin, ymin, xmax, ymax]`` boxes.
        """
        tree = etree.parse(xml_path)
        root = tree.getroot()

        # zip stops at the shorter sequence, so surplus objects keep their
        # original coordinates.
        for object_element, bbox in zip(root.findall("object"), bboxes):
            bndbox_element = object_element.find("bndbox")
            if bndbox_element is None:
                continue
            bndbox_element.find("xmin").text = str(bbox[0])
            bndbox_element.find("ymin").text = str(bbox[1])
            bndbox_element.find("xmax").text = str(bbox[2])
            bndbox_element.find("ymax").text = str(bbox[3])

        tree.write(xml_path, pretty_print=True, xml_declaration=True, encoding="UTF-8")
        print(f"XML 文件已更新并保存到: {xml_path}")

    def process(self):
        """Rotate every annotated image in the input folder.

        For each file whose extension (case-insensitively) is in
        ``IMAGE_EXTENSIONS``, the matching ``<stem>.xml`` is parsed, image and
        boxes are rotated with ``rotate_img_bbox``, and the results are written
        to the output folders.  Images without an XML file, and images OpenCV
        cannot decode, are reported on stdout and skipped.
        """
        for img_file in os.listdir(self.input_img_folder):
            if not img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                continue

            xml_file = os.path.splitext(img_file)[0] + '.xml'
            img_path = os.path.join(self.input_img_folder, img_file)
            xml_path = os.path.join(self.input_xml_folder, xml_file)

            if not os.path.exists(xml_path):
                print(f"XML 文件不存在: {xml_path}")
                continue

            img = cv2.imread(img_path)
            if img is None:
                # imread signals an unreadable/corrupt file by returning None.
                print(f"无法读取图片: {img_path}")
                continue

            root = etree.parse(xml_path).getroot()
            labels = []
            bboxes = []
            for obj in root.findall('object'):
                labels.append(obj.find('name').text)
                bboxes.append([
                    self._parse_coord(obj.find(f'bndbox/{tag}').text)
                    for tag in ('xmin', 'ymin', 'xmax', 'ymax')
                ])

            rotated_img, rotated_bboxes = self.rotate_img_bbox(img, bboxes, self.rotation_angle, self.scale)

            cv2.imwrite(os.path.join(self.output_img_folder, img_file), rotated_img)

            # The rotated canvas is larger than the source image, so the XML
            # must record the size of the image that was actually written.
            rot_height, rot_width = rotated_img.shape[:2]
            channel = rotated_img.shape[2] if rotated_img.ndim == 3 else 1
            self.save_xml(xml_file, self.output_xml_folder, ('folder_name', img_file), rot_height, rot_width, channel, (labels, rotated_bboxes))

if __name__ == "__main__":
    # Example configuration; adjust the folders to your own dataset layout.
    settings = dict(
        input_img_folder=r"E:\data\x1",   # folder with the source images
        input_xml_folder=r"E:\data\x2",   # folder with the source XML files
        output_img_folder=r"E:\data\y1",  # folder for the rotated images
        output_xml_folder=r"E:\data\y2",  # folder for the regenerated XML files
        rotation_angle=5,                 # rotation angle in degrees
        scale=1,                          # image scale factor, 1 by default
    )

    ImageXMLProcessor(**settings).process()