torch.utils.data.Dataset()
和torch.utils.data.DataLoader()
是Pytorch
中处理数据集和批量加载数据的重要工具。下面将详细介绍它们的作用、用法,并通过一个简单的例子来演示如何使用它们。
torch.utils.data.Dataset()
Dataset
是Pytorch
数据加载的基类,用于表示一个数据集。用户可以继承Dataset
类并实现其两个方法:__len__()
和__getitem__()
,这两个方法分别用于返回数据集的大小和获取数据集中的样本。
主要方法:
__len__(self):
返回数据集的样本数__getitem__(self,index):
给定一个索引,返回该索引对应的数据
示例代码:自定义数据集
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
"""
初始化数据集
:param data: 数据,通常是一个tensor或者numpy数组
:param labels: 标签,通常是一个tensor或者numpy数组
"""
self.data = data
self.labels = labels
def __len__(self):
"""
返回数据集的样本数量
"""
return len(self.data)
def __getitem__(self, idx):
"""
获取一个样本和对应的标签
:param idx: 数据的索引
:return: 一个包含数据和标签的元组
"""
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 创建一个简单的示例数据集
data = torch.randn(100, 3, 64, 64) # 100个64x64的RGB图像
labels = torch.randint(0, 2, (100,)) # 100个二分类标签
# 创建数据集实例
dataset = MyDataset(data, labels)
# 查看数据集的长度和第一个样本
print(f"Dataset length: {
len(dataset)}")
sample, label = dataset[0]
print(f"Sample shape: {
sample.shape}, Label: {
label}")
torch.utils.data.DataLoader
DataLoader
是一个用于批量加载数据的工具,它可以自动地将数据集切分为批次、打乱数据、并支持多线程加载等功能。DataLoader
可以传入一个Dataset对象来加载数据。
主要参数:
dataset
:数据集,通常是一个继承自Dataset
的对象batch_size
:每个批次的数据量(样本数)shuffle
:是否在每个epoch
后打乱数据,默认是False
num_workers
:加载数据时使用的子进程数,默认是0,即不使用多进程drop_last
: 如果数据集大小不是batch_size
的整数倍,是否丢弃最后一个不完整的批次,默认是False
。
示例代码:使用DataLoader
加载数据
from torch.utils.data import DataLoader
# 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
# 迭代 DataLoader 获取批次数据
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {
batch_idx+1}:")
print(f"Data shape: {
data.shape}, Labels shape: {
labels.shape}")
完整示例:自定义数据集+DataLoader
下面是一个包含自定义数据集和 DataLoader
的完整示例。我们使用简单的随机生成数据,模拟图像分类任务。
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Step 1: 自定义数据集
class RandomDataset(Dataset):
def __init__(self, num_samples, image_size, num_classes):
"""
初始化数据集
:param num_samples: 样本数量
:param image_size: 图像尺寸,例如(3, 64, 64)表示3通道64x64的图像
:param num_classes: 类别数量
"""
self.num_samples = num_samples
self.image_size = image_size
self.num_classes = num_classes
# 生成随机数据和标签
self.data = torch.randn(num_samples, *image_size) # 随机图像数据
self.labels = torch.randint(0, num_classes, (num_samples,)) # 随机标签
def __len__(self):
"""返回数据集的大小"""
return self.num_samples
def __getitem__(self, idx):
"""获取一个样本和标签"""
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# Step 2: 创建数据集实例
dataset = RandomDataset(num_samples=100, image_size=(3, 64, 64), num_classes=10)
# Step 3: 创建 DataLoader 实例
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
# Step 4: 迭代 DataLoader 获取批次数据
for batch_idx, (data, labels) in enumerate(dataloader):
print(f"Batch {
batch_idx+1}:")
print(f"Data shape: {
data.shape}, Labels shape: {
labels.shape}")
if batch_idx >= 2: # 仅打印前3个批次的数据
break
解释
:
- 自定义数据集 (
RandomDataset
):__init__(self, num_samples, image_size, num_classes)
:这个初始化方法接收样本数量、图像尺寸和类别数,并随机生成图像数据和对应的标签。__len__(self)
:返回数据集的大小,即样本数量。__getitem__(self, idx)
:根据索引返回数据集中的某一项数据和标签。
DataLoader
:batch_size=16
:每个批次包含 16 个样本。shuffle=True
:每次迭代时打乱数据。num_workers=2
:使用 2 个进程来加载数据(可以加速数据加载)。
- 迭代
DataLoader
:- 使用
for
循环遍历DataLoader
,每次返回一个批次的数据和标签,数据和标签是一个元组(data, labels)
。 - 在每个批次中,
data
是一个形状为(16, 3, 64, 64)
的张量,表示 16 张 3 通道 64x64 的图像。 labels
是一个形状为(16,)
的张量,表示每张图像的标签。
- 使用
总结
:
Dataset
类是PyTorch
数据加载的基础,用户需要继承它并实现__len__
和__getitem__
方法来自定义数据集。DataLoader
是一个方便的工具,用于批量加载数据,支持批量处理、数据打乱、并行加载等功能。- 在使用
DataLoader
时,数据集会被分成多个批次,每次迭代时返回一个批次的数据。
这些工具对于处理大型数据集非常有用,能够有效地提高模型训练的效率。