The previous article covered inference and deployment for rotated-box object detection with mmrotate. This time, let's summarize sliced inference with the YOLOv8 model for small-object detection.
First, an introduction to the techniques used for detecting small objects in large images.
1. Why is sliced inference needed, and what is slicing?
1.1 Sliced inference
Sliced inference is a technique that performs object detection by splitting a large image into multiple smaller patches (slices). It is mainly used to improve detection of small objects in high-resolution images.
High-resolution images contain a huge amount of pixel information. Feeding the whole image into the model at once can consume excessive compute, exhaust memory, or simply exceed the input size the model can handle. By slicing, the large image is decomposed into several smaller parts, which both lowers the demand on compute resources and avoids the performance degradation a model suffers when given an oversized input.
Slicing means dividing one large image into several smaller rectangular regions, each called a "slice". These slices usually overlap to some degree, so that objects near the image or slice edges are not cut apart or lost.
- Slice size: usually chosen according to the task. Smaller slices help detect small objects better, but also produce more slices and higher compute cost.
- Slice overlap: to keep an object from being split across a slice boundary, a certain overlap region is kept between neighboring slices, so that each object appears complete in at least one slice.
Sliced inference lets the model focus on local regions of the image and raises per-region detection accuracy, which matters most in scenes where objects are dense or overlapping.
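To make slice size and overlap concrete, here is a minimal sketch in plain Python that computes the slice grid for a given image; compute_slice_boxes is a hypothetical helper written for illustration, not a SAHI API:

# Hypothetical helper: compute top-left-anchored slice boxes for an image.
def compute_slice_boxes(img_w, img_h, slice_w=512, slice_h=512, overlap=0.2):
    stride_w = int(slice_w * (1 - overlap))  # 512 * 0.8 = 409 px step
    stride_h = int(slice_h * (1 - overlap))
    last_x = max(img_w - slice_w, 0)
    last_y = max(img_h - slice_h, 0)
    boxes = []
    for y in range(0, last_y + stride_h, stride_h):
        for x in range(0, last_x + stride_w, stride_w):
            # Clamp the final row/column of slices to the image border
            x0, y0 = min(x, last_x), min(y, last_y)
            boxes.append((x0, y0, x0 + slice_w, y0 + slice_h))
    return boxes

# A 2048x2048 image with 512x512 slices and 0.2 overlap yields a 5x5 grid
print(len(compute_slice_boxes(2048, 2048)))  # 25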
The sliced inference workflow
- Image splitting: cut the original large image into multiple small patches (slices).
- Per-slice inference: run object detection on each slice independently.
- Result merging: map all slice-level detections back into the coordinate system of the original image.
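As a sketch of the last two steps, assuming a hypothetical detect(patch) function that returns boxes in slice-local coordinates (and reusing compute_slice_boxes from the sketch above):

# Hypothetical sketch: per-slice inference plus coordinate remapping.
# detect(patch) is assumed to return [(x1, y1, x2, y2, score, class_id), ...].
def sliced_detect(image, boxes, detect):
    all_dets = []
    for (x0, y0, x1, y1) in boxes:
        patch = image[y0:y1, x0:x1]  # crop one slice (numpy HxWxC)
        for (bx1, by1, bx2, by2, score, cls) in detect(patch):
            # Shift slice-local boxes back into the full-image frame
            all_dets.append((bx1 + x0, by1 + y0, bx2 + x0, by2 + y0, score, cls))
    return all_dets  # duplicates from overlapping slices are removed later by NMS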
1.2 SAHI: Slicing Aided Hyper Inference
SAHI (Slicing Aided Hyper Inference) detects small objects through image slicing. The SAHI detection process can be described as follows: a sliding window cuts the image into several regions, each region is predicted separately, and inference is also run on the full image. The per-region predictions are then merged with the full-image predictions, and the combined result is filtered with NMS (Non-Maximum Suppression). (The original post illustrated this process with an animated figure, not reproduced here; the SAHI repository includes a similar demo animation.)
Official SAHI repository: https://github.com/obss/sahi
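SAHI exposes the full-image pass and the merging strategy as parameters of get_sliced_prediction. A sketch, assuming a recent SAHI release (parameter names and defaults may vary across versions):

# Sketch: slice predictions + a full-image pass, merged and filtered with NMS.
result = get_sliced_prediction(
    "big_image.tif",                 # hypothetical input image
    detection_model,                 # a loaded sahi AutoDetectionModel
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    perform_standard_pred=True,      # also run inference on the whole image
    postprocess_type="NMS",          # merge slice + full-image results with NMS
    postprocess_match_threshold=0.5, # IoU threshold for suppression
)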
1.3 Installing SAHI
SAHI is easy to install with the following command:
pip install sahi
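Note that SAHI's YOLOv8 backend relies on the ultralytics package, so if it is not installed yet you can install both together:
pip install sahi ultralytics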
2. Implementing sliced inference with YOLOv8
YOLOv8 itself does not support sliced inference directly, so it needs to be combined with the SAHI (Slicing Aided Hyper Inference) toolkit.
Using SAHI for sliced inference:
Create a Python script for the sliced inference, setting the slice size to 512×512 and the overlap ratio to 0.2.
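The get_sliced_prediction call below expects a detection_model, which is created from the YOLOv8 weights via SAHI's AutoDetectionModel (the weight path here is a placeholder; the full script further down uses the real one):

from sahi import AutoDetectionModel

detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',
    model_path="path/to/best.pt",  # placeholder: your trained YOLOv8 weights
    confidence_threshold=0.3,
    device="cuda:0",               # or "cpu"
)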
# Run sliced inference with SAHI
result = get_sliced_prediction(
    image_path,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2
)
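Besides iterating over result.object_prediction_list as the full script below does, the returned PredictionResult offers convenience exporters; a small sketch, assuming a recent SAHI release:

# `result` comes from the get_sliced_prediction call above
coco_anns = result.to_coco_annotations()       # detections as COCO-style dicts
result.export_visuals(export_dir="demo_out/")  # save an image with drawn boxes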
The complete script, sahi-predict.py, is shown below:
import os
from datetime import datetime
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom

import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction


def create_custom_xml(detections, class_names, image_info, save_path):
    # Create the root element
    annotation = ET.Element("annotation")

    # Add the source element
    source = ET.SubElement(annotation, "source")
    filename = ET.SubElement(source, "filename")
    filename.text = image_info['filename']
    origin = ET.SubElement(source, "origin")
    origin.text = "SAR"

    # Add the research element
    research = ET.SubElement(annotation, "research")
    version = ET.SubElement(research, "version")
    version.text = "1.0"
    author = ET.SubElement(research, "author")
    author.text = "MACW"
    pluginname = ET.SubElement(research, "pluginname")
    pluginname.text = "YOLOv10"
    pluginclass = ET.SubElement(research, "pluginclass")
    pluginclass.text = "Object Detection"
    time = ET.SubElement(research, "time")
    time.text = datetime.now().strftime("%Y-%m-%d")

    # Add the size element
    size = ET.SubElement(annotation, "size")
    width = ET.SubElement(size, "width")
    width.text = str(image_info['width'])
    height = ET.SubElement(size, "height")
    height.text = str(image_info['height'])
    depth = ET.SubElement(size, "depth")
    depth.text = str(image_info['depth'])

    # Add the objects element only when there are detections
    if len(detections) > 0:
        objects = ET.SubElement(annotation, "objects")
        # Append one entry per detected object
        for detection in detections:
            obj = ET.SubElement(objects, "object")
            coordinate = ET.SubElement(obj, "coordinate")
            coordinate.text = "pixel"
            type_ = ET.SubElement(obj, "type")
            type_.text = "rectangle"
            description = ET.SubElement(obj, "description")
            description.text = "Detected object"
            possibleresult = ET.SubElement(obj, "possibleresult")
            class_id = int(detection[5])
            name = ET.SubElement(possibleresult, "name")
            name.text = class_names[class_id]
            probability = ET.SubElement(possibleresult, "probability")
            probability.text = str(detection[4])
            points = ET.SubElement(obj, "points")
            xmin, ymin, xmax, ymax = detection[0], detection[1], detection[2], detection[3]
            points_data = [
                (xmin, ymin),
                (xmax, ymin),
                (xmax, ymax),
                (xmin, ymax),
                (xmin, ymin)  # close the polygon
            ]
            for point in points_data:
                point_element = ET.SubElement(points, "point")
                point_element.text = f"{point[0]:.6f},{point[1]:.6f}"

    # Serialize and pretty-print the XML tree
    rough_string = ET.tostring(annotation, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    pretty_xml_as_string = reparsed.toprettyxml(indent=" ")

    # Write the XML file
    with open(save_path, "w") as f:
        f.write(pretty_xml_as_string)


def main():
    # Input image directory
    input_path = '/input_path'
    # Output directory for the XML results
    output_path = '/output_path'

    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        # model_path="sar/sar/weights/last.pt",
        # Model weights file
        model_path="sar/sar_all-aug7/weights/best.pt",
        confidence_threshold=0.3,
        device="cuda:0",  # or "cpu"
    )

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    for filename in os.listdir(input_path):
        if filename.endswith(".tif"):
            # Extract the numeric part of the file name (dataset-specific filter)
            file_number = int(os.path.splitext(filename)[0])
            if file_number <= 750:
                continue
            image_path = os.path.join(input_path, filename)
            image = cv2.imread(image_path)
            height, width, depth = image.shape
            print(image_path)

            # Run sliced inference with SAHI
            result = get_sliced_prediction(
                image_path,
                detection_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.2,
                overlap_width_ratio=0.2
            )
            detections = result.object_prediction_list
            # print(detections)

            # Convert SAHI detections into flat lists
            detection_data = []
            class_names = {}
            for detection in detections:
                bbox = detection.bbox
                score = detection.score.value
                category_id = detection.category.id
                category_name = detection.category.name
                # Record the class name for each category id
                if category_id not in class_names:
                    class_names[category_id] = category_name
                detection_data.append([
                    bbox.minx, bbox.miny, bbox.maxx, bbox.maxy,
                    score, category_id
                ])

            image_info = {
                'filename': filename,
                'width': width,
                'height': height,
                'depth': depth
            }
            save_path = os.path.join(output_path, f"{os.path.splitext(filename)[0]}.xml")
            create_custom_xml(detection_data, class_names, image_info, save_path)


if __name__ == "__main__":
    main()
Run the script, and the YOLOv8 model will perform sliced inference on each image and save the results as XML files in the custom format defined above.
After the model finishes, a saved XML file looks like this:
<?xml version="1.0" ?>
<annotation>
 <source>
  <filename>1590.tif</filename>
  <origin>SAR</origin>
 </source>
 <research>
  <version>1.0</version>
  <author>MACW</author>
  <pluginname>YOLOv10</pluginname>
  <pluginclass>Object Detection</pluginclass>
  <time>2024-09-04</time>
 </research>
 <size>
  <width>600</width>
  <height>600</height>
  <depth>3</depth>
 </size>
 <objects>
  <object>
   <coordinate>pixel</coordinate>
   <type>rectangle</type>
   <description>Detected object</description>
   <possibleresult>
    <name>A320/321</name>
    <probability>0.9679625034332275</probability>
   </possibleresult>
   <points>
    <point>18.868906,490.197571</point>
    <point>108.454903,490.197571</point>
    <point>108.454903,575.142151</point>
    <point>18.868906,575.142151</point>
    <point>18.868906,490.197571</point>
   </points>
  </object>
 </objects>
</annotation>
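To sanity-check a generated file, you can parse it back with the standard-library ElementTree; a minimal sketch (the file path is a placeholder):

import xml.etree.ElementTree as ET

root = ET.parse("/output_path/1590.xml").getroot()  # placeholder path
for obj in root.findall("./objects/object"):
    name = obj.findtext("possibleresult/name")
    prob = float(obj.findtext("possibleresult/probability"))
    points = [tuple(map(float, p.text.split(",")))
              for p in obj.findall("points/point")]
    print(name, f"{prob:.3f}", points[0], points[2])  # class, score, opposite corners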
With this, a YOLOv8 model can run sliced inference for small-object detection and save the results as XML files.
If you are working with rotated-box object detection instead, see my other article: MMRotate rotated-box object detection on the DOTA dataset (model inference and deployment, saving inference results as XML files, and building a container image).