torch.utils.data.Dataset()和torch.utils.data.DataLoader()

torch.utils.data.Dataset()torch.utils.data.DataLoader()Pytorch中处理数据集和批量加载数据的重要工具。下面将详细介绍它们的作用、用法,并通过一个简单的例子来演示如何使用它们。

torch.utils.data.Dataset()

DatasetPytorch数据加载的基类,用于表示一个数据集。用户可以继承Dataset类并实现其两个方法:__len__()__getitem__(),这两个方法分别用于返回数据集的大小和获取数据集中的样本。

主要方法:

  • __len__(self):返回数据集的样本数
  • __getitem__(self,index):给定一个索引,返回该索引对应的数据

示例代码:自定义数据集

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        """
        初始化数据集
        :param data: 数据,通常是一个tensor或者numpy数组
        :param labels: 标签,通常是一个tensor或者numpy数组
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        """
        返回数据集的样本数量
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        获取一个样本和对应的标签
        :param idx: 数据的索引
        :return: 一个包含数据和标签的元组
        """
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label


# 创建一个简单的示例数据集
data = torch.randn(100, 3, 64, 64)  # 100个64x64的RGB图像
labels = torch.randint(0, 2, (100,))  # 100个二分类标签

# 创建数据集实例
dataset = MyDataset(data, labels)

# 查看数据集的长度和第一个样本
print(f"Dataset length: {
      
      len(dataset)}")
sample, label = dataset[0]
print(f"Sample shape: {
      
      sample.shape}, Label: {
      
      label}")

torch.utils.data.DataLoader

DataLoader是一个用于批量加载数据的工具,它可以自动地将数据集切分为批次、打乱数据、并支持多线程加载等功能。DataLoader可以传入一个Dataset对象来加载数据。

主要参数:

  • dataset:数据集,通常是一个继承自Dataset的对象
  • batch_size:每个批次的数据量(样本数)
  • shuffle:是否在每个epoch后打乱数据,默认是False
  • num_workers:加载数据时使用的子进程数,默认是0,即不使用多进程
  • drop_last: 如果数据集大小不是 batch_size 的整数倍,是否丢弃最后一个不完整的批次,默认是 False

示例代码:使用DataLoader加载数据

from torch.utils.data import DataLoader

# 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

# 迭代 DataLoader 获取批次数据
for batch_idx, (data, labels) in enumerate(dataloader):
    print(f"Batch {
      
      batch_idx+1}:")
    print(f"Data shape: {
      
      data.shape}, Labels shape: {
      
      labels.shape}")
完整示例:自定义数据集+DataLoader

下面是一个包含自定义数据集和 DataLoader 的完整示例。我们使用简单的随机生成数据,模拟图像分类任务。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Step 1: 自定义数据集
class RandomDataset(Dataset):
    def __init__(self, num_samples, image_size, num_classes):
        """
        初始化数据集
        :param num_samples: 样本数量
        :param image_size: 图像尺寸,例如(3, 64, 64)表示3通道64x64的图像
        :param num_classes: 类别数量
        """
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_classes = num_classes
        
        # 生成随机数据和标签
        self.data = torch.randn(num_samples, *image_size)  # 随机图像数据
        self.labels = torch.randint(0, num_classes, (num_samples,))  # 随机标签

    def __len__(self):
        """返回数据集的大小"""
        return self.num_samples

    def __getitem__(self, idx):
        """获取一个样本和标签"""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Step 2: 创建数据集实例
dataset = RandomDataset(num_samples=100, image_size=(3, 64, 64), num_classes=10)

# Step 3: 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

# Step 4: 迭代 DataLoader 获取批次数据
for batch_idx, (data, labels) in enumerate(dataloader):
    print(f"Batch {
      
      batch_idx+1}:")
    print(f"Data shape: {
      
      data.shape}, Labels shape: {
      
      labels.shape}")
    if batch_idx >= 2:  # 仅打印前3个批次的数据
        break
解释
  • 自定义数据集 (RandomDataset)
    • __init__(self, num_samples, image_size, num_classes):这个初始化方法接收样本数量、图像尺寸和类别数,并随机生成图像数据和对应的标签。
    • __len__(self):返回数据集的大小,即样本数量。
    • __getitem__(self, idx):根据索引返回数据集中的某一项数据和标签。
  • DataLoader
    • batch_size=16:每个批次包含 16 个样本。
    • shuffle=True:每次迭代时打乱数据。
    • num_workers=2:使用 2 个进程来加载数据(可以加速数据加载)。
  • 迭代 DataLoader
    • 使用 for 循环遍历 DataLoader,每次返回一个批次的数据和标签,数据和标签是一个元组 (data, labels)
    • 在每个批次中,data 是一个形状为 (16, 3, 64, 64) 的张量,表示 16 张 3 通道 64x64 的图像。
    • labels 是一个形状为 (16,) 的张量,表示每张图像的标签。
总结
  • Dataset 类是 PyTorch 数据加载的基础,用户需要继承它并实现 __len____getitem__ 方法来自定义数据集。
  • DataLoader 是一个方便的工具,用于批量加载数据,支持批量处理、数据打乱、并行加载等功能。
  • 在使用 DataLoader 时,数据集会被分成多个批次,每次迭代时返回一个批次的数据。

这些工具对于处理大型数据集非常有用,能够有效地提高模型训练的效率。