YOLOv11改进 | 网络结构代码逐行解析(四) | 手把手带你理解YOLOv11检测头输出到损失函数计算(新手入门必读系列)

一、本文介绍

本文给大家带来的是YOLOv11中从检测头结构分析到损失函数各种计算的详解,本文将从检测头的网络结构讲起,同时分析其中的原理(包括代码和网络结构图对比),最重要的是分析检测头的输出,因为检测头的输出是需要输出给损失函数的计算不同阶段的输出不一样所以我们在讲损失函数计算的时候需要先明白检测头的输出和其中的一些参数的定义,本文内容为我独家整理和分析,手打每一行的代码分析并包含各种举例分析对于小白来说绝对有所收获,全文共1万1千字。

 专栏回顾:YOLOv11改进系列专栏——本专栏持续复习各种顶会内容——科研必备


二、原理介绍

其实YOLOv11的检测头的网络结构非常的简单,其中主要就是解耦头,所以本文的主要内容是讲解其中的输出的含义,本文的内容会分为两节介绍,本文为章节上!


2.1 YOLOv11多头

其中有检测头部分被我用红框标注,其中有三个重复的结构,其中任何一个挑出即为我们的解耦头(即分开计算BboxLoss和ClsLoss)。

这三个部分大家可以看到仅是输入的特征图大小和通道数不一致,这也是我们所谓的大目标检测头和小目标检测头的区别(即20x20 和 80 x 80 区别。)  | 所以大目标检测头和小目标检测头其实调用的是同一个代码只是从Neck部分输入给检测头的特征图不一样从而产生了大目标和小目标检测头的定义而已。


2.2 YOLOv8解耦头代码分析 

YOLOv8检测头的分析在项目仓库的'ultralytics/nn/modules/head.py'仓库下,大家可以去找到现在改版了但是其实内容还是这些只是增加了一些功能但是对于YOLOv10是没有用到的。

import torch
import torch.nn as nn
import math


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # 推理时的参数是否是支持动态推理
    export = False  # 导出模式
    shape = None  # 可以保存一个预期的特张图的形状后面和实际输入特张图形状判断如果不同则进行额外处理,防止报错。
    anchors = torch.empty(0)  # 初始化 anchor 点
    strides = torch.empty(0)  # 初始化 strides

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # 类别数量(我们数据集中类别的数量,比如我的数据集中共有16个类别那么这里就是16)
        self.nl = len(ch)  # 检测层的数量(将Neck中的多少层输入给检测头,类似于默认的YOLOv8输入为(64, 128, 256)那么经过len(ch), self.nl= 3 因为元组中共有三个数字)

        self.reg_max = 16  # DFL 通道数量 (每个预测框的回归输出通道数, 具体来说就是YOLOv8是Anchor Free模型,
        # 对于每一个预测框我们需要确定四个坐标才能够知道其形状(也就是我们的检测出来物体的外部框形状)四个坐标分别为
        # x - 边界框的中心点 x 坐标
        # y - 边界框的中心点 y 坐标
        # w - 边界框的宽度
        # h - 边界框的高度

        self.no = nc + self.reg_max * 4  # 每个 anchor 的通道输出数量 (包括类别和回归信息) | 我这里的计算 是25 + 16 x 4 = 89 (后面会用到这个参数89)
        # 首先nc是我们的类别数量, self.reg_max为DFL通道数量,为什么要 * 4 代表每一个点的信息我们用16个通道与预测,四个点即 16 x 4 = 64

        self.stride = torch.zeros(self.nl)  # 初始化 stride, stride的含义即缩放比例是特征图相对于原始输入图像的缩放比例。
        # 假设原始输入图像的大小为 640x640,特征图的大小为 80x80,那么 stride 就是640÷80=8。这意味着输入图像中每 8 个像素对应特征图中的一个像素。

        # 计算两个不同通道数的参数 c2 和 c3,用于构建后续的卷积层, 这个参数主要决定计算量是一些中间那通道数.
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))

        # 定义两个卷积层序列列表,一个用于回归,一个用于分类, 这里就是我们网络结构图中的四个Conv!
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)

        # 定义一个 DFL 层,如果 reg_max > 1,则使用自定义的 DFL,否则使用 Identity, 默认为reg_max为16即DFL层被使用!
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        for i in range(self.nl):
            # 对每个检测层的特征图,分别通过回归和分类卷积层处理,然后拼接
            # 对于基础YOLOv8n我们的特征图输入分别为, 三个检测层的输入分别为!
            # (batch_size, 64, 80, 80)
            # (batch_size, 128, 40, 40)
            # (batch_size, 256, 20, 20)
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
            # 这里处理后的形状变为
            # (batch_size, 89, 80, 80)
            # (batch_size, 89, 40, 40)
            # (batch_size, 89, 20, 20)
            # 89 这个参数我们在之前讲过了如何计算得来的, 其中包含了16 x 4 = 64 为坐标的信息, 剩余通道包含了cls类别的信息!
        if self.training:  # 如果是训练则到此就结束了!
            return x # 训练时直接返回 x 即上面的(batch_size, 89, 80, 80)

        # 推理开始 或者是训练时候的验证阶段
        shape = x[0].shape  # BCHW (批量大小, 通道数, 高度, 宽度) 需要注意的是验证阶段的batch会自动变为训练阶段的2倍, 假设你训练时候的batch为 4 那么此时就会是 8

        # 将每个特征图展平,并沿通道维度拼接, 这里比较重要!
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        # 这里的形状我们的x_cat = (batch_size, self.no=89, 20 x 20 + 40 x 40 + 80 x 80 = 8400)
        # 这里解释一下这个形状的含义(batch_size, 89, 8400)
        # 第一个batch_size就不解释了,大家都应该明白了就是我们训练过程中的batch的设置,代表一次将多少个图片输入给模型.
        # 第二个89我们应该也明白了前面也介绍了前 前16 x 4 包含的是位置信息 xywh, 后25包含的是我数据集中的类别信息
        # 需要注意的是8400这个概念大家可能不理解,因为我们是anchor free,我们需要对大中小三个检测头的每个像素都进行预测,前面说了20 x 20 + 40 x 40 + 80 x 80 = 8400
        # 那么这里就可以理解为我们对每个图片每个像素点都进行预测, 每个点都预测了16 x 4 包含的是位置信息 xywh, 后25包含的是我数据集中的类别信息, 然后我们共有batch_size个图片
        # 然后每个点都预测我们这里先不用理会这里就是单纯的预测结果!
        # (batch_size张图片, 89=各种信息, 8400=8400个像素点) 综合来翻译就是batch_size张图片,每个图片共产生了8400百个点的各种信息其中这些信息就是我们每个点的全部预测结果(不知道这么说大家能不能理解)

        if self.dynamic or self.shape != shape:
            # 生成 anchors 和 strides
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            # 上面这一步的含义涉及到了代码make_anchors其返回的两个变量我们需要理解一下self.anchors, self.strides
            # self.ancors形状为(2, 8400)其中的含义是坐标信息,
            # self.strides为我们的比例信息每一个像素点的
            # 这里不理解没关系这段代码涉及到的make_anchors我下面会单独讲!
            # 其中的0.5为每个像素点举例像素中心的聚集为0.5个单元因为是正方形!
            self.shape = shape

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # 避免 TF FlexSplitV ops
            # 分割 box 和分类的输出
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            # 分割 box 和分类的输出(即将我们x_cat分为两部分一部分是前64个通道代表bbox信息, 后面的通道代表cls信息)
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
            # box形状为(1, 64, 8400)
            # box形状为(1, nc=25, 8400)
        if self.export and self.format in {"tflite", "edgetpu"}:
            # 预计算归一化因子以提高数值稳定性
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            # 解码边界框
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            # 解码边界框
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
            # dobx = (1, 4, 8400) 此时已经确定了预测的四个位置信息
        # 拼接解码后的边界框和分类概率,并返回
        y = torch.cat((dbox, cls.sigmoid()), 1)  
        # y = (1, 29, 8400) 4 + nc=25 =29
        # 此时这里就包含了真实的边界框信息 和 各个类别的分类概率!
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # 初始化每个检测层的偏置
            a[-1].bias.data[:] = 1.0  # box 偏置初始化为1.0
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # 类别偏置初始化,基于目标数量和图像大小

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=True, dim=1)

上面我们解释了YOLOv8检测头中的各个输出信息的含义,但是其中还有两个函数文章中是调用的并没有代码,分别是dist2bbox和make_anchors以及DFL,下面我们按照重要顺序来讲解!


2.2.1 DFL 

DFL是一个用于对象检测的损失函数模块,主要用于提高边界框回归的精度。它的核心思想是将每个预测的边界框参数(如 x, y, w, h)分解为多个通道,然后通过 softmax 操作得到一个分布,并计算分布的积分来预测实际值。

下面是部分代码解析!

class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

c1:表示输入通道数,默认为 16。
self.conv:定义了一个卷积层,输入通道数为 c1`,输出通道数为 1,卷积核大小为 1x1。这个卷积层的权重初始化为 0 到 c1-1 的浮点数,并且不更新requires_grad_(False)不更新梯度的意思代表。
self.c1:保存输入通道数。

def forward(self, x):
    """Applies a transformer layer on input tensor 'x' and returns a tensor."""
    b, _, a = x.shape  # batch, channels, anchors
    return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

输入 x:形状为 [batch_size, channels, anchors]的张量,其中 channels为 4× c1,每个边界框参数(如 x, y, w, h)都有 c1个通道(这里和我们前面的解释一致)。

下面的操作为正向传播中最后一行代码的解析! 

1.变形操作: 

   x.view(b, 4, self.c1, a)

将 x变形为 [batch_size, 4, c1, anchors]的张量,其中 4 表示四个边界框参数(x, y, w, h),c1 是每个参数的通道数。

2. 转置操作:

.transpose(2, 1)

将张量转置为 [batch_size, c1, 4, anchors],将通道维度 c1 和参数维度 4交换。

3. softmax 操作:

   .softmax(1)

在通道维度 c1上应用 softmax,得到每个参数的概率分布。

4. 卷积操作:

self.conv(...)

使用卷积层 self.conv将分布的积分计算出来。卷积层的权重初始化为 0 到 c1-1,这样卷积操作实际上计算的是这些通道的加权平均值。

5. 变形回原始形状:

.view(b, 4, a)

最终将输出变形为 [batch_size, 4, anchors],即每个锚点的四个边界框参数。

具体操作流程

1. 输入形状:假设输入张量 x的形状为 [2, 64, 100](batch_size=2,channels=64,anchors=100),其中 64 是 4 × c1(每个参数 16 个通道)。
2. 变形:将输入张量变形为 [2, 4, 16, 100],表示每个锚点的四个参数(x, y, w, h),每个参数有 16 个通道。
3. 转置:将张量转置为 [2, 16, 4, 100]。
4. softmax:在通道维度 16 上应用 softmax,得到每个参数的概率分布。
5. 卷积:使用卷积层计算分布的积分,得到每个参数的加权平均值。
6. 输出形状:将输出变形为 [2, 4, 100],即每个锚点的四个边界框参数。

DFL 模块通过将每个边界框参数分解为多个通道,使用 softmax 获得概率分布,然后通过卷积计算分布的积分来预测实际的边界框参数。这样的方法可以提高边界框回归的精度。


2.2.2 make_anchors

 下面的代码是我们检测头代码中详解没有解析的内容大家可以去上面找一下!

def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)

 好的,让我们具体解释一下每个 anchor 如何表示特征图中每个单元格的位置(需要先理解的概念)。

什么是 Anchor

在对象检测中,anchor 是一个预定义的边界框模板,用于在特征图的每个单元格(即特征图的每个位置)上进行预测。每个 anchor 具有固定的大小和形状,模型通过调整这些 anchors 来拟合实际的物体边界框。

特征图和输入图像

假设我们有一个输入图像和对应的特征图:

  • 输入图像大小:640x640
  • 特征图大小:80x80

特征图的每个单元格对应于输入图像的一个区域。具体来说,特征图的一个单元格覆盖输入图像的 8x8 像素(假设 stride 为 8,前面讲过如何计算就是放缩比例640 ÷ 80 = 8)。

Anchor 在特征图中的位置

当我们在特征图上生成 anchors 时,每个单元格中心都会有一个 anchor。anchor 的位置用单元格的坐标表示(例如,特征图的第 (i, j) 个单元格,这些其实是图像基础的内容)。### 具体实现

在 make_anchors函数中,我们通过以下步骤生成 anchors(代码中内容的解析):

1. 生成网格点坐标

sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
  • torch.arange(end=w, device=device, dtype=dtype)生成从 0 到 w-1 的序列,表示特征图宽度方向上的索引。
  • torch.arange(end=h, device=device, dtype=dtype)生成从 0 到 h-1 的序列,表示特征图高度方向上的索引。
  • grid_cell_offset(默认为 0.5)用于将网格点移动到单元格的中心。

2. 生成网格坐标的所有组合

sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
  • torch.meshgrid(sy, sx)生成网格点的所有组合,即每个单元格的中心坐标。

3. 组合并展平

anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
  • torch.stack((sx, sy), -1)将 x 和 y 坐标堆叠在一起,形成形状为 (h, w, 2)的张量。
  • .view(-1, 2) 将其展平为形状为 (h * w, 2)的二维张量,其中每一行表示一个 anchor 的中心点坐标。

例子说明(用具体例子给大家阐述一下辅助大家理解一下)

假设我们有一个 3x3 的特征图,其单元格的中心坐标如下:
(0.5, 0.5)   (1.5, 0.5)   (2.5, 0.5)
(0.5, 1.5)   (1.5, 1.5)   (2.5, 1.5)
(0.5, 2.5)   (1.5, 2.5)   (2.5, 2.5)

使用 make_anchors函数生成的 anchor points 将是一个形状为 (9, 2)的张量 (数学内容),每一行表示一个单元格的中心坐标(其中的.5为偏移量!):
[[0.5, 0.5],
 [1.5, 0.5],
 [2.5, 0.5],
 [0.5, 1.5],
 [1.5, 1.5],
 [2.5, 1.5],
 [0.5, 2.5],
 [1.5, 2.5],
 [2.5, 2.5]]

总结

每个 anchor 表示特征图中每个单元格的位置,这些位置是特征图网格点的中心坐标。通过 make_anchors函数,我们生成了这些 anchors,并将它们用于对象检测模型的边界框预测,同时还有sride为放缩比例同时计算。


2.2.3 dist2bbox

下面的代码为位置坐标信息的解码操作!

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox

这段代码的作用是将距离转换为边界框坐标,可以选择转换为中心点坐标和宽高(xywh)或者左上角和右下角坐标(xyxy)。

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
  • distance:表示从 anchor points 到边界框的左、上、右、下边界的距离(left, top, right, bottom)。
  • anchor_points:anchor points 的坐标。
  • xywh:布尔值,决定输出边界框的格式,默认为 `True,表示输出为中心点坐标和宽高(xywh),否则输出为左上角和右下角坐标(xyxy)。
  • dim:指明在哪个维度上切分距离张量,默认值为 -1。

核心逻辑

1. 切分距离张量

lt, rb = distance.chunk(2, dim)

将 distance 张量分为左右两个部分:

  • lt:left, top 距离
  • rb:right, bottom 距离
  • chunk(2, dim) 表示在指定维度 dim 上将 distance切分为 2 个张量。

2. 计算左上角和右下角坐标

x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
  • x1y1:左上角坐标,由 anchor points 减去 left 和 top 距离得到。
  • x2y2:右下角坐标,由 anchor points 加上 right 和 bottom 距离得到。

 3. 转换为指定格式的边界框

 如果 xywh 为 True,转换为中心点坐标和宽高:

if xywh:
    c_xy = (x1y1 + x2y2) / 2
    wh = x2y2 - x1y1
    return torch.cat((c_xy, wh), dim)  # xywh bbox
  • c_xy:中心点坐标,由左上角和右下角坐标的平均值计算得到。
  • wh:宽高,由右下角坐标减去左上角坐标计算得到。

最终返回中心点坐标和宽高的组合。

如果 xywh为 False,返回左上角和右下角坐标:

return torch.cat((x1y1, x2y2), dim)  # xyxy bbox

最终返回左上角和右下角坐标的组合。

示例

假设 distance为一个形状为 [N, 4] 的张量,其中 N是 anchor points 的数量,4表示 left, top, right, bottom 距离。anchor_points为一个形状为 [N, 2]的张量,表示每个 anchor 的坐标。

例子:转换为 xywh

distance = torch.tensor([[1, 1, 2, 2], [2, 2, 3, 3]])
anchor_points = torch.tensor([[5, 5], [10, 10]])
xywh = True

bbox = dist2bbox(distance, anchor_points, xywh)

步骤:
1. 切分 distance为 lt和 rb:

  • lt:[[1, 1], [2, 2]]
  • rb:[[2, 2], [3, 3]]

2. 计算 x1y1和 x2y2:

  • x1y1:[[4, 4], [8, 8]]
  • x2y2:[[7, 7], [13, 13]]

3. 转换为中心点坐标和宽高:

  • c_xy:[[5.5, 5.5], [10.5, 10.5]]
  • wh:[[3, 3], [5, 5]]

4. 返回的 bbox为 [[5.5, 5.5, 3, 3], [10.5, 10.5, 5, 5]]

例子:转换为 xyxy

xywh = False

bbox = dist2bbox(distance, anchor_points, xywh)

步骤:
1. 切分 distance为 lt 和rb:

  • lt:[[1, 1], [2, 2]]
  • rb:[[2, 2], [3, 3]]

2. 计算 x1y1和 x2y2:

  • x1y1:[[4, 4], [8, 8]]
  • x2y2:[[7, 7], [13, 13]]

3. 返回的bbox为 [[4, 4, 7, 7], [8, 8, 13, 13]]

总结

dist2bbox`函数将表示 left, top, right, bottom 距离的张量转换为边界框坐标,可以选择转换为中心点坐标和宽高(xywh)或者左上角和右下角坐标(xyxy)。通过这个函数,模型可以从预测的距离值中计算出实际的边界框坐标。

本文内容到此结束,本文为检测头的分析下一章节将推出损失函数的计算解析!


三、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

 专栏回顾:YOLOv11改进系列专栏——本专栏持续复习各种顶会内容——科研必备

d2e5d4828bd84bc79d11a9bd3ef13a35.png

猜你喜欢

转载自blog.csdn.net/java1314777/article/details/143100722