前言
在深度学习领域,尤其是目标检测任务中,数据集的质量直接影响模型的性能。为了提升模型的鲁棒性和对各种场景的适应能力,数据增强技术被广泛应用于图像数据集处理。旋转是一种常见的数据增强方法,通过对图像及其对应的标签(边界框)进行同步旋转,可以有效地增强模型对不同方向目标的识别能力。然而,在进行旋转变换时,除了图像本身需要旋转外,与之对应的标签(边界框)也需要同步更新。否则,标签位置的偏差将导致模型训练出现误差,甚至无法正确识别目标。
本文将介绍如何对目标检测数据集中的图像及其标签(边界框)进行同步旋转,且旋转后的边界框仍为水平矩形。通过同步旋转图像和标签,能够确保数据增强过程中图像与标签的一致性,从而提高模型训练的效果和准确性。
效果如下图所示:图像发生倾斜旋转,边界框随之同步更新。
代码
import os
import cv2
import math
import shutil
import numpy as np
from lxml import etree, objectify
class ImageXMLProcessor:
    """Rotate images and their Pascal VOC bounding boxes together.

    For every image in ``input_img_folder`` that has a same-named ``.xml``
    annotation in ``input_xml_folder``, the image is rotated by
    ``rotation_angle`` degrees onto a canvas large enough to hold the whole
    rotated image. Every box is replaced by an axis-aligned box computed from
    its rotated edge midpoints. The rotated image and a freshly written XML
    file are saved to the output folders.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, input_img_folder, input_xml_folder, output_img_folder, output_xml_folder, rotation_angle=5, scale=1):
        """Store the folder/rotation settings and create the output folders.

        Args:
            input_img_folder: Folder containing the source images.
            input_xml_folder: Folder containing the source VOC XML files.
            output_img_folder: Folder that receives the rotated images.
            output_xml_folder: Folder that receives the rewritten XML files.
            rotation_angle: Rotation in degrees (OpenCV convention: positive
                means counter-clockwise).
            scale: Isotropic scale factor applied together with the rotation.
        """
        self.input_img_folder = input_img_folder
        self.input_xml_folder = input_xml_folder
        self.output_img_folder = output_img_folder
        self.output_xml_folder = output_xml_folder
        self.rotation_angle = rotation_angle
        self.scale = scale
        os.makedirs(self.output_img_folder, exist_ok=True)
        os.makedirs(self.output_xml_folder, exist_ok=True)

    def rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
        """Rotate ``img`` and ``bboxes`` with the same affine transform.

        Args:
            img: Image array as returned by ``cv2.imread``.
            bboxes: Iterable of boxes whose first four items are
                ``xmin, ymin, xmax, ymax`` in pixel coordinates.
            angle: Rotation in degrees; positive is counter-clockwise.
            scale: Isotropic scale factor.

        Returns:
            ``(rot_img, rot_bboxes)``. ``rot_img`` is the rotated image on an
            enlarged canvas, so no pixels are cropped. ``rot_bboxes`` is a list
            of ``[xmin, ymin, xmax, ymax]`` Python ints in ``rot_img``
            coordinates, clipped to the canvas.
        """
        h, w = img.shape[:2]
        rangle = np.deg2rad(angle)
        # Size of the canvas that fully contains the rotated (and scaled) image.
        nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
        nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
        # Extra translation so the original image centre lands on the centre
        # of the new canvas.
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]
        out_w = int(math.ceil(nw))
        out_h = int(math.ceil(nh))
        rot_img = cv2.warpAffine(img, rot_mat, (out_w, out_h), flags=cv2.INTER_LANCZOS4)

        rot_bboxes = []
        for bbox in bboxes:
            xmin, ymin, xmax, ymax = bbox[:4]
            cx = (xmin + xmax) / 2
            cy = (ymin + ymax) / 2
            # Rotate the four edge midpoints rather than the corners. Corner
            # rotation always inflates the box. Midpoints give a tighter box,
            # at the cost of possibly not covering the full rotated rectangle.
            edge_mids = np.array([
                [cx, ymin, 1],
                [xmax, cy, 1],
                [cx, ymax, 1],
                [xmin, cy, 1],
            ], dtype=np.float64)
            pts = edge_mids @ rot_mat.T  # (4, 2) rotated points
            # Take min/max directly instead of cv2.boundingRect: boundingRect's
            # width is inclusive (max - min + 1), which inflated xmax/ymax by
            # one pixel on every call (even for a 0-degree rotation).
            rx_min = int(np.clip(np.rint(pts[:, 0].min()), 0, out_w - 1))
            ry_min = int(np.clip(np.rint(pts[:, 1].min()), 0, out_h - 1))
            rx_max = int(np.clip(np.rint(pts[:, 0].max()), 0, out_w - 1))
            ry_max = int(np.clip(np.rint(pts[:, 1].max()), 0, out_h - 1))
            rot_bboxes.append([rx_min, ry_min, rx_max, ry_max])
        return rot_img, rot_bboxes

    def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
        """Write a new Pascal VOC annotation file.

        Args:
            file_name: Name of the XML file to create inside ``save_folder``.
            save_folder: Destination folder (created if missing).
            img_info: ``(folder_name, img_name)`` stored in <folder>,
                <filename> and <path>.
            height, width, channel: Image size written to <size>.
            bboxs_info: ``(labels, bboxes)``. The two sequences are paired
                element-wise, and each box is ``[xmin, ymin, xmax, ymax]``.

        Note: every object is written with pose ``Unspecified`` and
        truncated/difficult ``0``. Any such flags from a source annotation
        are not carried over.
        """
        folder_name, img_name = img_info
        E = objectify.ElementMaker(annotate=False)
        # Build the XML tree.
        anno_tree = E.annotation(
            E.folder(folder_name),
            E.filename(img_name),
            E.path(os.path.join(folder_name, img_name)),
            E.source(
                E.database('Unknown'),
            ),
            E.size(
                E.width(width),
                E.height(height),
                E.depth(channel)
            ),
            E.segmented(0),
        )
        labels, bboxs = bboxs_info
        for label, box in zip(labels, bboxs):
            anno_tree.append(
                E.object(
                    E.name(label),
                    E.pose('Unspecified'),
                    E.truncated('0'),
                    E.difficult('0'),
                    E.bndbox(
                        E.xmin(box[0]),
                        E.ymin(box[1]),
                        E.xmax(box[2]),
                        E.ymax(box[3])
                    )
                ))
        os.makedirs(save_folder, exist_ok=True)
        file_path = os.path.join(save_folder, file_name)
        etree.ElementTree(anno_tree).write(file_path, pretty_print=True, xml_declaration=True, encoding="UTF-8")
        print(f"XML 文件已保存到: {file_path}")

    def update_bboxes_in_xml(self, xml_path, bboxes):
        """Overwrite the <bndbox> coordinates of an existing XML file in place.

        The i-th <object> receives ``bboxes[i]``. Objects beyond
        ``len(bboxes)``, and objects without a <bndbox>, are left untouched.
        The <size> element is not modified.
        """
        tree = etree.parse(xml_path)
        root = tree.getroot()
        for object_element, bbox in zip(root.findall("object"), bboxes):
            bndbox_element = object_element.find("bndbox")
            if bndbox_element is None:
                continue
            for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), bbox):
                coord_element = bndbox_element.find(tag)
                if coord_element is not None:
                    coord_element.text = str(value)
        tree.write(xml_path, pretty_print=True, xml_declaration=True, encoding="UTF-8")
        print(f"XML 文件已更新并保存到: {xml_path}")

    @staticmethod
    def _read_voc_objects(xml_path):
        """Return ``(labels, bboxes)`` parsed from a VOC XML file's <object> entries."""
        root = etree.parse(xml_path).getroot()
        labels = []
        bboxes = []
        for obj in root.findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                # An object without a box has nothing to rotate; skip it.
                continue
            # Some tools store coordinates as floats (e.g. "12.0"); plain
            # int() would raise ValueError on those.
            coords = [int(round(float(bndbox.find(tag).text)))
                      for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            labels.append(obj.find('name').text)
            bboxes.append(coords)
        return labels, bboxes

    def process(self):
        """Rotate every image/annotation pair found in the input folders."""
        # Folder name recorded in the output XML <folder>/<path> fields.
        folder_name = os.path.basename(os.path.normpath(self.output_img_folder))
        for img_file in sorted(os.listdir(self.input_img_folder)):
            if not img_file.lower().endswith(self.IMG_EXTENSIONS):
                continue
            xml_file = os.path.splitext(img_file)[0] + '.xml'
            img_path = os.path.join(self.input_img_folder, img_file)
            xml_path = os.path.join(self.input_xml_folder, xml_file)
            if not os.path.exists(xml_path):
                print(f"XML 文件不存在: {xml_path}")
                continue
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print(f"无法读取图片: {img_path}")
                continue
            labels, bboxes = self._read_voc_objects(xml_path)
            rotated_img, rotated_bboxes = self.rotate_img_bbox(img, bboxes, self.rotation_angle, self.scale)
            output_img_path = os.path.join(self.output_img_folder, img_file)
            if not cv2.imwrite(output_img_path, rotated_img):
                print(f"图片保存失败: {output_img_path}")
                continue
            # The rotated canvas is larger than the input, so <size> must
            # describe the rotated image, not the original one.
            height, width = rotated_img.shape[:2]
            channel = rotated_img.shape[2] if rotated_img.ndim == 3 else 1
            self.save_xml(xml_file, self.output_xml_folder, (folder_name, img_file),
                          height, width, channel, (labels, rotated_bboxes))
if __name__ == "__main__":
    import argparse

    # Command-line entry point. The defaults equal the original hard-coded
    # settings, so running the script without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Rotate images and their Pascal VOC bounding boxes together.")
    parser.add_argument("--input-img", default=r"E:\data\x1", help="input image folder")
    parser.add_argument("--input-xml", default=r"E:\data\x2", help="input XML folder")
    parser.add_argument("--output-img", default=r"E:\data\y1", help="output image folder")
    parser.add_argument("--output-xml", default=r"E:\data\y2", help="output XML folder")
    parser.add_argument("--angle", type=float, default=5,
                        help="rotation angle in degrees (positive = counter-clockwise)")
    parser.add_argument("--scale", type=float, default=1,
                        help="scale factor applied together with the rotation")
    args = parser.parse_args()

    processor = ImageXMLProcessor(args.input_img, args.input_xml, args.output_img,
                                  args.output_xml, args.angle, args.scale)
    processor.process()