实现YOLO V3数据加载器:从文件系统读取图像与标签。。

引言

在深度学习项目中,数据准备是非常重要的一环。特别是在物体检测任务中,数据的组织和预处理直接影响到模型的训练效果。YOLO V3(You Only Look Once Version 3)作为一种高效的实时物体检测框架,其数据加载器的设计对于确保模型训练的顺利进行至关重要。本文将详细介绍如何使用Python和PyTorch实现一个YOLO V3的数据加载器,以支持从文件系统中读取图像及其对应的标签文件,并进行必要的预处理。

数据集组织

首先,我们需要了解数据集是如何组织的。通常情况下,图像数据集会被分成两个主要的部分:

  1. 图像文件:这些文件通常保存在磁盘上,并且每张图像都有一个对应的文件名。
  2. 标签文件:与图像文件一一对应,每个标签文件记录了图像中物体的位置和类别信息。

标签文件通常是以文本格式存储的,每一行代表一个物体的边界框信息,格式为 类别 中心点X 中心点Y 宽度 高度。这些值通常是归一化的,即相对于图像的宽度和高度而言。

数据加载器实现

为了实现YOLO V3的数据加载器,我们需要创建一个继承自torch.utils.data.Dataset的类,并重写其__init____len____getitem__方法。此外,我们还需要定义一些辅助函数来处理图像的预处理工作,如填充、调整大小等。

导入必要的库

import glob
import random
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms

这段代码导入了构建数据加载器所需的库,包括文件处理、图像处理、张量操作等。

辅助函数定义

pad_to_square 函数
# 定义将图像填充为正方形的函数
def pad_to_square(img, pad_value):
    # 获取图像的通道数、高度和宽度
    c, h, w = img.shape
    # 计算高度和宽度的差值
    dim_diff = np.abs(h - w)
    # 计算需要在较短边填充的数量
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # 根据高度和宽度确定填充的方向(左侧/右侧 或 上侧/下侧)
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    # 对图像进行填充
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad

此函数将图像填充为正方形,以便后续处理。它计算图像的高度和宽度之差,并据此决定在哪个方向上添加填充。

resize 函数

# 定义调整图像大小的函数
def resize(image, size):
    # 使用最近邻插值调整图像大小
    image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
    return image

此函数将图像调整为指定的大小。它首先将图像的维度扩展到 (1, C, H, W),然后使用最近邻插值进行缩放,最后再压缩回原来的维度。

数据集类定义

ListDataset 类
# 定义数据集类
class ListDataset(Dataset):
    # 初始化方法
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        # 读取包含图像路径的列表文件
        with open(list_path, "r") as file:
            self.img_files = file.readlines()
        # 根据图像路径找到对应的标签文件路径
        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        # 设置图像大小
        self.img_size = img_size
        # 最大对象数量
        self.max_objects = 100
        # 是否启用数据增强
        self.augment = augment
        # 是否启用多尺度训练
        self.multiscale = multiscale
        # 标签是否已经归一化
        self.normalized_labels = normalized_labels
        # 设置图像大小的最小值和最大值
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        # 用于追踪批次计数
        self.batch_count = 0

这是数据集类的初始化方法,它读取包含图像路径的列表文件,并根据图像路径找到对应的标签文件路径。

__getitem__ 方法
     # 获取数据集中指定索引的项目
    def __getitem__(self, index):
        # 获取图像路径
        img_path = self.img_files[index % len(self.img_files)].rstrip()
        # 读取图像并转换为张量
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
        
        # 如果图像不是三通道,则扩展为三通道
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            img = img.expand((3, img.shape[1:]))
        
        # 获取图像的高度和宽度
        _, h, w = img.shape
        # 计算填充因子
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # 将图像填充为正方形
        img, pad = pad_to_square(img, 0)
        # 获取填充后的图像的高度和宽度
        _, padded_h, padded_w = img.shape
        
        # 获取标签文件路径
        label_path = self.label_files[index % len(self.img_files)].rstrip()
        targets = None
        if os.path.exists(label_path):
            # 读取标签文件
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # 计算边界框的真实坐标
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # 考虑到填充的影响,调整边界框坐标
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]
            # 重新归一化边界框坐标
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h
            # 创建目标张量
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes
        
        # 应用数据增强
        if self.augment and np.random.random() < 0.5:
            img, targets = horisontal_flip(img, targets)
        
        return img_path, img, targets

此方法用于获取数据集中单个样本。它读取图像,进行必要的预处理(如转换为张量、填充至正方形、调整大小),并读取对应的标签文件,调整边界框坐标以适应图像处理后的变化。

collate_fn 方法
  # 用于处理一批数据的方法
    def collate_fn(self, batch):
        # 解压批次数据
        paths, imgs, targets = zip(*batch)
        # 移除空的占位标签
        targets = [boxes for boxes in targets if boxes is not None]
        # 给每个目标添加样本索引
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        # 合并所有目标
        targets = torch.cat(targets, 0)
        # 每十个批次选择一个新的图像大小
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # 调整图像大小
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        # 更新批次计数
        self.batch_count += 1
        return paths, imgs, targets

此方法用于处理从数据集中获取的一批数据。它合并不同样本的标签,并根据需要调整图像大小。

__len__ 方法
    # 返回数据集中样本的数量
    def __len__(self):
        return len(self.img_files)

返回数据集中样本的数量。

猜你喜欢

转载自blog.csdn.net/m0_73697499/article/details/143418535