多分类自定义采样比例

多分类自定义采样比例

import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder

# 假设你有一个自定义的数据集类
class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.dataset = ImageFolder(data_dir, transform=transform)
        self.class_weights = self.calculate_class_weights()

    def calculate_class_weights(self):
        # 计算每个类别的样本权重,可以根据不同的策略进行调整
        class_counts = torch.tensor([self.dataset.targets.count(i) for i in range(len(self.dataset.classes))])
        class_weights = 1.0 / class_counts
        return class_weights

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

# 数据集目录
data_dir = "path/to/your/dataset"

# 定义图像转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    tra

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/134563749