CIFAR10 数据集自定义处理方法
可以自定义训练集和测试集中不同类别的样本的数量。可用于模拟类别不平衡问题,存在混淆数据问题。
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
# 自定义数据集类,继承自 torch.utils.data.Dataset
class CustomCIFAR10Dataset(Dataset):
def __init__(self, images, labels, transform=None):
"""
自定义数据集类
:param images: 图像数据,numpy 数组格式
:param labels: 标签数据,numpy 数组格式
:param transform: 可选的图像预处理转换
"""
self.images = images
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
image = self.images[index]
if self.transform:
image = self.transform(image)
label = self.labels[index]
return image, label
def create_custom_dataset(positive_classes, negative_classes, sample_counts=None, transform=None, train=True):
"""
创建自定义数据集(训练集或测试集)
:param positive_classes: 正类别的类别列表
:param negative_classes: 负类别的类别列表
:param sample_counts: 每个类别的样本数量限制,字典形式 {类: 样本数量}
:param transform: 图像预处理转换
:param train: 是否是训练集(True)还是测试集(False)
:return: 创建的自定义数据集(CustomCIFAR10Dataset)和原始数据集
"""
# 下载 CIFAR-10 数据集(训练集或测试集)
dataset = dsets.CIFAR10(root='./data', train=train, download=True, transform=transforms.ToTensor())
images = dataset.data # numpy array, shape [N, 32, 32, 3]
targets = np.array(dataset.targets) # shape [N]
new_images = []
new_labels = []
selected_global_indices = []
for cls in np.concatenate((positive_classes, negative_classes)):
# 获取当前类别的样本索引
indices = np.where(targets == cls)[0]
# 如果有样本数量限制,则抽取样本
if sample_counts is not None and cls in sample_counts:
num_samples = min(sample_counts[cls], len(indices))
selected_indices = np.random.choice(indices, num_samples, replace=False)
else:
selected_indices = indices
selected_global_indices.extend(selected_indices.tolist())
# 为正类别标签为1,负类别标签为0
for idx in selected_indices:
new_images.append(images[idx])
if cls in positive_classes:
new_labels.append(1)
else:
new_labels.append(0)
# 转换为 numpy 数组
new_images = np.array(new_images)
new_labels = np.array(new_labels)
# 打乱新数据集
perm = np.random.permutation(len(new_labels))
new_images = new_images[perm]
new_labels = new_labels[perm]
# 创建自定义数据集
custom_dataset = CustomCIFAR10Dataset(new_images, new_labels, transform=transform)
return custom_dataset, dataset
if __name__ == '__main__':
# 定义正类别和负类别
positive_classes = [0, 1, 2, 3, 4]
negative_classes = [5, 6, 7, 8, 9]
# 定义每个类别需要抽取的样本数量
sample_counts = {
0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
# 图像预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 训练数据集
train_dataset, base_train_dataset = create_custom_dataset(positive_classes, negative_classes, sample_counts, transform, train=True)
print('Training dataset size:', len(train_dataset))
# 测试数据集
positive_classes_test = [0]
negative_classes_test = [5, 6, 7, 8, 9]
sample_counts_test = {
0: 1000, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
test_dataset, base_test_dataset = create_custom_dataset(positive_classes_test, negative_classes_test, sample_counts_test, transform, train=False)
print('Test dataset size:', len(test_dataset))
# 使用 DataLoader 加载数据集
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 打印加载器中的数据量
for images, labels in train_loader:
print(f"Batch size: {
len(images)}, Labels: {
labels}")
break
代码详细解释文档
1. 自定义数据集类 CustomCIFAR10Dataset
此类继承自 torch.utils.data.Dataset
,用于自定义数据集的管理,具体功能如下:
__init__
: 初始化方法,接受图像数据、标签数据和可能的图像预处理变换。__len__
: 返回数据集的长度,即样本数量。__getitem__
: 根据索引返回样本图像和标签,若定义了预处理变换,则应用该变换。
2. create_custom_dataset
函数
此函数用于创建训练集或测试集,并按类别划分和抽样。
positive_classes
: 正类别的类别列表,标签为 1。negative_classes
: 负类别的类别列表,标签为 0。sample_counts
: 可选,字典形式,指定每个类别的样本数量限制。如果没有该参数,则使用所有样本。transform
: 可选,图像预处理变换。train
: 是否为训练集。如果为True
,则加载训练集;如果为False
,则加载测试集。
3. 数据集的处理流程
- 从 CIFAR-10 下载训练集或测试集,获取图像数据和标签。
- 根据给定的类别信息,抽取所需类别的图像样本,并为正类分配标签为 1,负类分配标签为 0。
- 如果有样本数量限制,则从每个类别中随机选择样本。
- 将抽取的图像和标签打乱顺序,并创建自定义数据集
CustomCIFAR10Dataset
。
4. 训练集和测试集的使用
在主程序中:
- 定义正类别和负类别,以及每个类别的样本数量限制。
- 使用
create_custom_dataset
创建训练集和测试集。 - 使用
DataLoader
加载数据集,设置批次大小并进行数据打乱。
5. DataLoader
的使用
DataLoader
用于加载训练数据,并将其按批次处理。我们将自定义数据集传入DataLoader
并设置批次大小为 64。- 在循环中,打印每个批次的大小和标签信息。
6. 输出示例
运行此代码时,您将看到类似以下的输出:
Training dataset size: 5000
Test dataset size: 3500
Batch size: 64, Labels: tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])
优化说明
- 代码中使用了
np.random.permutation
来打乱数据集的顺序,确保数据的随机性。 - 自定义数据集和图像预处理功能让代码具有灵活性,能够方便地处理不同任务的需求。
- 使用
DataLoader
来批量加载数据,提升训练效率。