引言
在深度学习项目中,数据准备是非常重要的一环。特别是在物体检测任务中,数据的组织和预处理直接影响到模型的训练效果。YOLO V3(You Only Look Once Version 3)作为一种高效的实时物体检测框架,其数据加载器的设计对于确保模型训练的顺利进行至关重要。本文将详细介绍如何使用Python和PyTorch实现一个YOLO V3的数据加载器,以支持从文件系统中读取图像及其对应的标签文件,并进行必要的预处理。
数据集组织
首先,我们需要了解数据集是如何组织的。通常情况下,图像数据集会被分成两个主要的部分:
- 图像文件:这些文件通常保存在磁盘上,并且每张图像都有一个对应的文件名。
- 标签文件:与图像文件一一对应,每个标签文件记录了图像中物体的位置和类别信息。
标签文件通常是以文本格式存储的,每一行代表一个物体的边界框信息,格式为 类别 中心点X 中心点Y 宽度 高度
。这些值通常是归一化的,即相对于图像的宽度和高度而言。
数据加载器实现
为了实现YOLO V3的数据加载器,我们需要创建一个继承自torch.utils.data.Dataset
的类,并重写其__init__
、__len__
和__getitem__
方法。此外,我们还需要定义一些辅助函数来处理图像的预处理工作,如填充、调整大小等。
导入必要的库
import glob
import random
import os
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as transforms
这段代码导入了构建数据加载器所需的库,包括文件处理、图像处理、张量操作等。
辅助函数定义
pad_to_square
函数
# 定义将图像填充为正方形的函数
def pad_to_square(img, pad_value):
# 获取图像的通道数、高度和宽度
c, h, w = img.shape
# 计算高度和宽度的差值
dim_diff = np.abs(h - w)
# 计算需要在较短边填充的数量
pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
# 根据高度和宽度确定填充的方向(左侧/右侧 或 上侧/下侧)
pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
# 对图像进行填充
img = F.pad(img, pad, "constant", value=pad_value)
return img, pad
此函数将图像填充为正方形,以便后续处理。它计算图像的高度和宽度之差,并据此决定在哪个方向上添加填充。
resize
函数
# 定义调整图像大小的函数
def resize(image, size):
# 使用最近邻插值调整图像大小
image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
return image
此函数将图像调整为指定的大小。它首先将图像的维度扩展到 (1, C, H, W)
,然后使用最近邻插值进行缩放,最后再压缩回原来的维度。
数据集类定义
ListDataset
类
# 定义数据集类
class ListDataset(Dataset):
# 初始化方法
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
# 读取包含图像路径的列表文件
with open(list_path, "r") as file:
self.img_files = file.readlines()
# 根据图像路径找到对应的标签文件路径
self.label_files = [
path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
for path in self.img_files
]
# 设置图像大小
self.img_size = img_size
# 最大对象数量
self.max_objects = 100
# 是否启用数据增强
self.augment = augment
# 是否启用多尺度训练
self.multiscale = multiscale
# 标签是否已经归一化
self.normalized_labels = normalized_labels
# 设置图像大小的最小值和最大值
self.min_size = self.img_size - 3 * 32
self.max_size = self.img_size + 3 * 32
# 用于追踪批次计数
self.batch_count = 0
这是数据集类的初始化方法,它读取包含图像路径的列表文件,并根据图像路径找到对应的标签文件路径。
__getitem__
方法
# 获取数据集中指定索引的项目
def __getitem__(self, index):
# 获取图像路径
img_path = self.img_files[index % len(self.img_files)].rstrip()
# 读取图像并转换为张量
img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))
# 如果图像不是三通道,则扩展为三通道
if len(img.shape) != 3:
img = img.unsqueeze(0)
img = img.expand((3, img.shape[1:]))
# 获取图像的高度和宽度
_, h, w = img.shape
# 计算填充因子
h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
# 将图像填充为正方形
img, pad = pad_to_square(img, 0)
# 获取填充后的图像的高度和宽度
_, padded_h, padded_w = img.shape
# 获取标签文件路径
label_path = self.label_files[index % len(self.img_files)].rstrip()
targets = None
if os.path.exists(label_path):
# 读取标签文件
boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
# 计算边界框的真实坐标
x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
# 考虑到填充的影响,调整边界框坐标
x1 += pad[0]
y1 += pad[2]
x2 += pad[1]
y2 += pad[3]
# 重新归一化边界框坐标
boxes[:, 1] = ((x1 + x2) / 2) / padded_w
boxes[:, 2] = ((y1 + y2) / 2) / padded_h
boxes[:, 3] *= w_factor / padded_w
boxes[:, 4] *= h_factor / padded_h
# 创建目标张量
targets = torch.zeros((len(boxes), 6))
targets[:, 1:] = boxes
# 应用数据增强
if self.augment and np.random.random() < 0.5:
img, targets = horisontal_flip(img, targets)
return img_path, img, targets
此方法用于获取数据集中单个样本。它读取图像,进行必要的预处理(如转换为张量、填充至正方形、调整大小),并读取对应的标签文件,调整边界框坐标以适应图像处理后的变化。
collate_fn
方法
# 用于处理一批数据的方法
def collate_fn(self, batch):
# 解压批次数据
paths, imgs, targets = zip(*batch)
# 移除空的占位标签
targets = [boxes for boxes in targets if boxes is not None]
# 给每个目标添加样本索引
for i, boxes in enumerate(targets):
boxes[:, 0] = i
# 合并所有目标
targets = torch.cat(targets, 0)
# 每十个批次选择一个新的图像大小
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
# 调整图像大小
imgs = torch.stack([resize(img, self.img_size) for img in imgs])
# 更新批次计数
self.batch_count += 1
return paths, imgs, targets
此方法用于处理从数据集中获取的一批数据。它合并不同样本的标签,并根据需要调整图像大小。
__len__
方法
# 返回数据集中样本的数量
def __len__(self):
return len(self.img_files)
返回数据集中样本的数量。