CIFAR10 数据集自定义处理方法

CIFAR10 数据集自定义处理方法

可以自定义训练集和测试集中不同类别的样本的数量。可用于模拟类别不平衡问题,存在混淆数据问题。

import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random

# 自定义数据集类,继承自 torch.utils.data.Dataset
class CustomCIFAR10Dataset(Dataset):
    def __init__(self, images, labels, transform=None):
        """
        自定义数据集类
        :param images: 图像数据,numpy 数组格式
        :param labels: 标签数据,numpy 数组格式
        :param transform: 可选的图像预处理转换
        """
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform:
            image = self.transform(image)
        label = self.labels[index]
        return image, label

def create_custom_dataset(positive_classes, negative_classes, sample_counts=None, transform=None, train=True):
    """
    创建自定义数据集(训练集或测试集)
    :param positive_classes: 正类别的类别列表
    :param negative_classes: 负类别的类别列表
    :param sample_counts: 每个类别的样本数量限制,字典形式 {类: 样本数量}
    :param transform: 图像预处理转换
    :param train: 是否是训练集(True)还是测试集(False)
    :return: 创建的自定义数据集(CustomCIFAR10Dataset)和原始数据集
    """
    # 下载 CIFAR-10 数据集(训练集或测试集)
    dataset = dsets.CIFAR10(root='./data', train=train, download=True, transform=transforms.ToTensor())
    images = dataset.data  # numpy array, shape [N, 32, 32, 3]
    targets = np.array(dataset.targets)  # shape [N]
    
    new_images = []
    new_labels = []
    selected_global_indices = []

    for cls in np.concatenate((positive_classes, negative_classes)):
        # 获取当前类别的样本索引
        indices = np.where(targets == cls)[0]
        
        # 如果有样本数量限制,则抽取样本
        if sample_counts is not None and cls in sample_counts:
            num_samples = min(sample_counts[cls], len(indices))
            selected_indices = np.random.choice(indices, num_samples, replace=False)
        else:
            selected_indices = indices
        
        selected_global_indices.extend(selected_indices.tolist())
        
        # 为正类别标签为1,负类别标签为0
        for idx in selected_indices:
            new_images.append(images[idx])
            if cls in positive_classes:
                new_labels.append(1)
            else:
                new_labels.append(0)

    # 转换为 numpy 数组
    new_images = np.array(new_images)
    new_labels = np.array(new_labels)
    
    # 打乱新数据集
    perm = np.random.permutation(len(new_labels))
    new_images = new_images[perm]
    new_labels = new_labels[perm]
    
    # 创建自定义数据集
    custom_dataset = CustomCIFAR10Dataset(new_images, new_labels, transform=transform)
    return custom_dataset, dataset

if __name__ == '__main__':
    # 定义正类别和负类别
    positive_classes = [0, 1, 2, 3, 4]
    negative_classes = [5, 6, 7, 8, 9]
    
    # 定义每个类别需要抽取的样本数量
    sample_counts = {
    
    0: 500, 1: 500, 2: 500, 3: 500, 4: 500, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
    
    # 图像预处理
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    # 训练数据集
    train_dataset, base_train_dataset = create_custom_dataset(positive_classes, negative_classes, sample_counts, transform, train=True)
    print('Training dataset size:', len(train_dataset))

    # 测试数据集
    positive_classes_test = [0]
    negative_classes_test = [5, 6, 7, 8, 9]
    sample_counts_test = {
    
    0: 1000, 5: 500, 6: 500, 7: 500, 8: 500, 9: 500}
    
    test_dataset, base_test_dataset = create_custom_dataset(positive_classes_test, negative_classes_test, sample_counts_test, transform, train=False)
    print('Test dataset size:', len(test_dataset))

    # 使用 DataLoader 加载数据集
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # 打印加载器中的数据量
    for images, labels in train_loader:
        print(f"Batch size: {
      
      len(images)}, Labels: {
      
      labels}")
        break

代码详细解释文档

1. 自定义数据集类 CustomCIFAR10Dataset

此类继承自 torch.utils.data.Dataset,用于自定义数据集的管理,具体功能如下:

  • __init__: 初始化方法,接受图像数据、标签数据和可能的图像预处理变换。
  • __len__: 返回数据集的长度,即样本数量。
  • __getitem__: 根据索引返回样本图像和标签,若定义了预处理变换,则应用该变换。
2. create_custom_dataset 函数

此函数用于创建训练集或测试集,并按类别划分和抽样。

  • positive_classes: 正类别的类别列表,标签为 1。
  • negative_classes: 负类别的类别列表,标签为 0。
  • sample_counts: 可选,字典形式,指定每个类别的样本数量限制。如果没有该参数,则使用所有样本。
  • transform: 可选,图像预处理变换。
  • train: 是否为训练集。如果为 True,则加载训练集;如果为 False,则加载测试集。
3. 数据集的处理流程
  • 从 CIFAR-10 下载训练集或测试集,获取图像数据和标签。
  • 根据给定的类别信息,抽取所需类别的图像样本,并为正类分配标签为 1,负类分配标签为 0。
  • 如果有样本数量限制,则从每个类别中随机选择样本。
  • 将抽取的图像和标签打乱顺序,并创建自定义数据集 CustomCIFAR10Dataset
4. 训练集和测试集的使用

在主程序中:

  1. 定义正类别和负类别,以及每个类别的样本数量限制。
  2. 使用 create_custom_dataset 创建训练集和测试集。
  3. 使用 DataLoader 加载数据集,设置批次大小并进行数据打乱。
5. DataLoader 的使用
  • DataLoader 用于加载训练数据,并将其按批次处理。我们将自定义数据集传入 DataLoader 并设置批次大小为 64。
  • 在循环中,打印每个批次的大小和标签信息。
6. 输出示例

运行此代码时,您将看到类似以下的输出:

Training dataset size: 5000
Test dataset size: 3500
Batch size: 64, Labels: tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])

优化说明

  • 代码中使用了 np.random.permutation 来打乱数据集的顺序,确保数据的随机性。
  • 自定义数据集和图像预处理功能让代码具有灵活性,能够方便地处理不同任务的需求。
  • 使用 DataLoader 来批量加载数据,提升训练效率。