Sliced inference for small-object detection with YOLOv8, saving the results as XML files

My previous article covered inference and deployment for rotated-box object detection with mmrotate. This time I will summarize sliced inference for small-object detection with the YOLOv8 model.
First, an introduction to the techniques used for detecting small objects in large images.

1. Why do we need sliced inference? What is slicing?

1.1 Sliced inference

Sliced inference is a technique that performs object detection by splitting a large image into multiple small patches (slices). It is mainly used to improve detection of small objects and of high-resolution images.

High-resolution images contain a huge amount of pixel information. Feeding the entire image into the model at once can consume excessive compute, run out of memory, or simply exceed the input size the model can handle. Slicing breaks the large image into smaller parts, which lowers the resource requirements and avoids the performance degradation that comes with oversized inputs.

Slicing divides a large image into multiple smaller rectangular regions, each called a "slice." These slices usually overlap to some degree, so that objects near slice edges are not cut apart or missed.

  • Slice size: usually chosen according to the task. Smaller slices help detect small objects better, but also produce more slices and a higher computational cost.
  • Slice overlap: to keep objects from being split across slice boundaries, a certain overlap region is kept between adjacent slices. This guarantees that each object appears complete in at least one slice.
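As a back-of-envelope check on this trade-off, the number of slices can be estimated from the slice size and overlap ratio. This is a rough sketch of the arithmetic only; SAHI's own slicer may round the stride slightly differently:

```python
import math

def slices_per_axis(image_size, slice_size, overlap_ratio):
    """Estimate how many slices cover one image dimension."""
    if image_size <= slice_size:
        return 1
    step = int(slice_size * (1 - overlap_ratio))  # stride between slice origins
    return math.ceil((image_size - slice_size) / step) + 1

# A 600x600 image with 512x512 slices and 0.2 overlap:
# stride = 512 * 0.8 = 409, so 2 slices per axis, 4 slices in total.
print(slices_per_axis(600, 512, 0.2) ** 2)  # -> 4
```

Halving the slice size roughly quadruples the slice count, which is where the extra computational cost comes from.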

Sliced inference lets the model focus on local regions of the image, improving detection accuracy within each region, which is especially important in scenes where objects are dense or overlapping.

The sliced inference process

  1. Image slicing: split the original large image into multiple small patches (slices).
  2. Per-slice inference: run object detection on each slice independently.
  3. Result merging: map all slice-level detections back into the coordinate system of the original image.
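The three steps above can be sketched in a few lines. Here `detect_fn` is a hypothetical per-slice detector (a stand-in for any model, not part of the original script) that returns boxes in slice-local coordinates; edge handling is omitted, and the real work in step 3 is just adding the slice origin back to each box:

```python
def sliced_inference(image_w, image_h, detect_fn, slice_size=512, overlap=0.2):
    """Slice an image_w x image_h image, run detect_fn on each slice,
    and map the boxes back to full-image coordinates."""
    step = int(slice_size * (1 - overlap))
    merged = []
    for y0 in range(0, image_h, step):
        for x0 in range(0, image_w, step):
            # detect_fn(x0, y0) is assumed to return slice-local boxes:
            # [(x1, y1, x2, y2, score, class_id), ...]
            for x1, y1, x2, y2, score, cls in detect_fn(x0, y0):
                # Step 3: shift slice-local coordinates by the slice origin
                merged.append((x1 + x0, y1 + y0, x2 + x0, y2 + y0, score, cls))
    return merged  # a real pipeline (e.g. SAHI) then applies NMS
```

Because of the overlap, the same object can be detected in several slices, which is why a deduplication pass (NMS) is needed at the end.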

1.2 SAHI: Slicing Aided Hyper Inference

SAHI (Slicing Aided Hyper Inference) detects small objects by slicing the image. The SAHI detection process can be described as follows: a sliding window splits the image into several regions and each region is predicted separately, while inference is also run on the full image. The per-region predictions and the full-image predictions are then merged, and finally filtered with NMS (non-maximum suppression). The SAHI repository illustrates this process with an animation.

Official SAHI repository: https://github.com/obss/sahi
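The final NMS filtering step mentioned above can be illustrated with a minimal greedy implementation. This is a sketch of the standard algorithm, not SAHI's actual merge code:

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union else 0.0

def nms(boxes, scores, iou_thresh=0.5):
    """Greedy NMS: keep the highest-scoring box, drop near-duplicates."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) < iou_thresh for j in keep):
            keep.append(i)
    return keep

# Two overlapping detections of the same object, plus one distinct object:
print(nms([(0, 0, 10, 10), (1, 1, 10, 10), (20, 20, 30, 30)],
          [0.9, 0.8, 0.7]))  # -> [0, 2]
```

The duplicate boxes that sliced inference produces in the overlap regions are exactly what this pass removes.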


1.3 Installing SAHI

Installing SAHI is simple; use the following command:

pip install sahi

2. Implementing sliced inference with YOLOv8

YOLOv8 does not support sliced inference out of the box, so we combine it with the SAHI (Slicing Aided Hyper Inference) toolkit.

Using SAHI for sliced inference:
Create a Python script for the sliced inference.

Set the slice size to 512×512 with an overlap ratio of 0.2:

            # Run sliced inference with SAHI
            result = get_sliced_prediction(
                image_path,
                detection_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.2,
                overlap_width_ratio=0.2
            )

The complete script, sahi-predict.py, is as follows:

import os
from datetime import datetime
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction


def create_custom_xml(detections, class_names, image_info, save_path):
    # Create the root element
    annotation = ET.Element("annotation")
    
    # Add the source element
    source = ET.SubElement(annotation, "source")
    filename = ET.SubElement(source, "filename")
    filename.text = image_info['filename']
    origin = ET.SubElement(source, "origin")
    origin.text = "SAR"

    # Add the research element
    research = ET.SubElement(annotation, "research")
    version = ET.SubElement(research, "version")
    version.text = "1.0"
    author = ET.SubElement(research, "author")
    author.text = "MACW"
    pluginname = ET.SubElement(research, "pluginname")
    pluginname.text = "YOLOv10"
    pluginclass = ET.SubElement(research, "pluginclass")
    pluginclass.text = "Object Detection"
    time = ET.SubElement(research, "time")
    time.text = datetime.now().strftime("%Y-%m-%d")

    # Add the size element
    size = ET.SubElement(annotation, "size")
    width = ET.SubElement(size, "width")
    width.text = str(image_info['width'])
    height = ET.SubElement(size, "height")
    height.text = str(image_info['height'])
    depth = ET.SubElement(size, "depth")
    depth.text = str(image_info['depth'])

    # If there are detections, add the objects element
    if len(detections) > 0:
        objects = ET.SubElement(annotation, "objects")

        # Add each detection's information
        for detection in detections:
            obj = ET.SubElement(objects, "object")
            
            coordinate = ET.SubElement(obj, "coordinate")
            coordinate.text = "pixel"
            
            type_ = ET.SubElement(obj, "type")
            type_.text = "rectangle"
            
            description = ET.SubElement(obj, "description")
            description.text = "Detected object"
            
            possibleresult = ET.SubElement(obj, "possibleresult")
            class_id = int(detection[5])
            name = ET.SubElement(possibleresult, "name")
            name.text = class_names[class_id]
            probability = ET.SubElement(possibleresult, "probability")
            probability.text = str(detection[4])
            
            points = ET.SubElement(obj, "points")
            xmin, ymin, xmax, ymax = detection[0], detection[1], detection[2], detection[3]
            points_data = [
                (xmin, ymin),
                (xmax, ymin),
                (xmax, ymax),
                (xmin, ymax),
                (xmin, ymin)  # close the polygon
            ]
            
            for point in points_data:
                point_element = ET.SubElement(points, "point")
                point_element.text = f"{point[0]:.6f},{point[1]:.6f}"

    # Build the XML tree and pretty-print it
    rough_string = ET.tostring(annotation, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent="    ")

    # Save the XML file
    with open(save_path, "w") as f:
        f.write(pretty_xml_as_string)

def main():
    # Input directory of images
    input_path = '/input_path'
    # Output directory for XML files
    output_path = '/output_path'
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        # model_path="sar/sar/weights/last.pt",
        # Model weights file
        model_path="sar/sar_all-aug7/weights/best.pt",
        confidence_threshold=0.3,
        device="cuda:0",  # or "cpu"
    )

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    for filename in os.listdir(input_path):
        if filename.endswith(".tif"):
            # Extract the numeric part of the filename; skip files numbered 750 or below
            file_number = int(os.path.splitext(filename)[0])
            if file_number <= 750:
                continue

            image_path = os.path.join(input_path, filename)
            image = cv2.imread(image_path)
            height, width, depth = image.shape
            print(image_path)
            # Run sliced inference with SAHI
            result = get_sliced_prediction(
                image_path,
                detection_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.2,
                overlap_width_ratio=0.2
            )
            detections = result.object_prediction_list
            # print(detections)

            
            # Process the SAHI detection results
            detection_data = []
            class_names = {}
            for detection in detections:
                bbox = detection.bbox
                score = detection.score.value
                category_id = detection.category.id
                category_name = detection.category.name           

                # Record the class name
                if category_id not in class_names:
                    class_names[category_id] = category_name

                
                detection_data.append([
                    bbox.minx, bbox.miny, bbox.maxx, bbox.maxy,
                    score, category_id
                ])

            
            image_info = {
                'filename': filename,
                'width': width,
                'height': height,
                'depth': depth
            }

            save_path = os.path.join(output_path, f"{os.path.splitext(filename)[0]}.xml")
            create_custom_xml(detection_data, class_names, image_info, save_path)

if __name__ == "__main__":
    main()

Run the script, and the YOLOv8 model will perform sliced inference on each image and save the results as XML files in the format you defined.

After the model runs successfully, the saved XML file looks like this:

<?xml version="1.0" ?>
<annotation>
    <source>
        <filename>1590.tif</filename>
        <origin>SAR</origin>
    </source>
    <research>
        <version>1.0</version>
        <author>MACW</author>
        <pluginname>YOLOv10</pluginname>
        <pluginclass>Object Detection</pluginclass>
        <time>2024-09-04</time>
    </research>
    <size>
        <width>600</width>
        <height>600</height>
        <depth>3</depth>
    </size>
    <objects>
        <object>
            <coordinate>pixel</coordinate>
            <type>rectangle</type>
            <description>Detected object</description>
            <possibleresult>
                <name>A320/321</name>
                <probability>0.9679625034332275</probability>
            </possibleresult>
            <points>
                <point>18.868906,490.197571</point>
                <point>108.454903,490.197571</point>
                <point>108.454903,575.142151</point>
                <point>18.868906,575.142151</point>
                <point>18.868906,490.197571</point>
            </points>
        </object>
    </objects>
</annotation>
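To sanity-check or consume these files downstream, the XML can be read back with the same standard library that produced it. The helper name and return shape below are my own, not part of the original script:

```python
import xml.etree.ElementTree as ET

def parse_annotation(xml_text):
    """Parse an annotation XML string into (filename, list of detections)."""
    root = ET.fromstring(xml_text)
    filename = root.findtext("source/filename")
    detections = []
    for obj in root.iterfind("objects/object"):
        name = obj.findtext("possibleresult/name")
        prob = float(obj.findtext("possibleresult/probability"))
        # Each <point> is "x,y"; the last point repeats the first to close the box
        corners = [tuple(map(float, p.text.split(",")))
                   for p in obj.iterfind("points/point")]
        detections.append((name, prob, corners))
    return filename, detections
```

For files on disk, `ET.parse(path).getroot()` replaces `ET.fromstring`.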

With that, the YOLOv8 model performs sliced inference for small-object detection and saves the results as XML files.

If you are working with rotated-box object detection, see my other article: Training MMRotate rotated-box object detection on the DOTA dataset (model inference and deployment, saving inference results as XML files, and building a Docker image).


Reposted from blog.csdn.net/MacWx/article/details/141906664