Splitting and Loading an ImageNet-like Dataset

1. Introduction

A dataset is usually divided into train, val and test splits. When there is no officially published standard split, we need to split the dataset ourselves.

2. Splitting the dataset

First, the dataset should follow the structure below: all images are placed into subfolders named after their class.

(Figure: the dataset root folder, containing one subfolder per class.)
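For reference, the layout in the figure matches the one repeated in the docstring of the loading code later in this post (class_1, class_2 and the image names here are just placeholders):

dataset_root/
    |-- class_1
        |-- 0001.jpg
        |-- 0002.jpg
        |-- ...
    |-- class_2
        |-- 0001.jpg
        |-- 0002.jpg
        |-- ...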
The splitting code is as follows:

import argparse
import random
import os
import os.path as osp


def is_pic(img_name):
    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
    suffix = img_name.split('.')[-1]
    if suffix not in valid_suffix:
        return False
    return True


def list_files(dirname):
    """ 列出目录下所有文件(包括所属的一级子目录下文件)
    Args:
        dirname: 目录路径
    """

    def filter_file(f):
        if f.startswith('.'):
            return True
        return False

    all_files = list()
    dirs = list()
    for f in os.listdir(dirname):
        if filter_file(f):
            continue
        if osp.isdir(osp.join(dirname, f)):
            dirs.append(f)
        else:
            all_files.append(f)
    for d in dirs:
        for f in os.listdir(osp.join(dirname, d)):
            if filter_file(f):
                continue
            if osp.isdir(osp.join(dirname, d, f)):
                continue
            all_files.append(osp.join(d, f))
    return all_files


def split_imagenet_dataset(dataset_dir, val_percent, test_percent, save_dir):
    all_files = list_files(dataset_dir)
    label_list = list()
    train_image_anno_list = list()
    val_image_anno_list = list()
    test_image_anno_list = list()
    for file in all_files:
        if not is_pic(file):
            continue
        label, image_name = osp.split(file)
        if label not in label_list:
            label_list.append(label)
    label_list = sorted(label_list)

    for i in range(len(label_list)):
        image_list = list_files(osp.join(dataset_dir, label_list[i]))
        image_anno_list = list()
        for img in image_list:
            image_anno_list.append([osp.join(label_list[i], img), i])
        random.shuffle(image_anno_list)
        image_num = len(image_anno_list)
        val_num = int(image_num * val_percent)
        test_num = int(image_num * test_percent)
        train_num = image_num - val_num - test_num

        train_image_anno_list += image_anno_list[:train_num]
        val_image_anno_list += image_anno_list[train_num:train_num + val_num]
        test_image_anno_list += image_anno_list[train_num + val_num:]

    with open(
            osp.join(save_dir, 'train_list.txt'), mode='w',
            encoding='utf-8') as f:
        for x in train_image_anno_list:
            file, label = x
            f.write('{} {}\n'.format(file, label))
    with open(
            osp.join(save_dir, 'val_list.txt'), mode='w',
            encoding='utf-8') as f:
        for x in val_image_anno_list:
            file, label = x
            f.write('{} {}\n'.format(file, label))
    if len(test_image_anno_list):
        with open(
                osp.join(save_dir, 'test_list.txt'), mode='w',
                encoding='utf-8') as f:
            for x in test_image_anno_list:
                file, label = x
                f.write('{} {}\n'.format(file, label))
    with open(
            osp.join(save_dir, 'labels.txt'), mode='w', encoding='utf-8') as f:
        for l in sorted(label_list):
            f.write('{}\n'.format(l))

    return len(train_image_anno_list), len(val_image_anno_list), len(
        test_image_anno_list)


def dataset_split(dataset_dir, val_value, test_value, save_dir):
    print("Dataset split starts...")
    train_num, val_num, test_num = split_imagenet_dataset(
        dataset_dir, val_value, test_value, save_dir)
    print("Dataset split done.")
    print("Train samples: {}".format(train_num))
    print("Eval samples: {}".format(val_num))
    print("Test samples: {}".format(test_num))
    print("Split files saved in {}".format(save_dir))


def main():
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--dataset_dir", required=True, help="数据集根目录")
    # parser.add_argument("--val_value", required=True, help="验证集比例")
    # parser.add_argument("--test_value", required=True, help="测试集比例")
    # parser.add_argument("--save_dir", required=True, help="划分文件保存地址")
    # args = parser.parse_args()
    dataset_dir=r"D:\WorkSpace\LittlePapersFlow\dataset\ckp\merge"
    val_value=0.3
    test_value=0
    save_dir=dataset_dir
    dataset_split(dataset_dir, val_value, test_value, save_dir)


if __name__ == "__main__":
    main()

  • dataset_dir: the dataset root directory, i.e. the absolute path of the root folder shown in the figure above
  • val_value: fraction of images assigned to the validation set (e.g. for the common 7:3:0 split, the validation set takes 0.3)
  • test_value: fraction of images assigned to the test set
  • save_dir: where the generated split files are saved; the code above writes them into the dataset root folder, but any other path can be given

Note: the training-set fraction = 1 - val_value - test_value.
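After the script finishes, each line of train_list.txt / val_list.txt (and test_list.txt when test_value > 0) holds an image path relative to the dataset root followed by its numeric label index, and labels.txt lists the class folder names in sorted order, one per line. With two hypothetical classes named cat and dog, the files would look roughly like this:

train_list.txt:
cat/0003.jpg 0
cat/0011.jpg 0
...
dog/0002.jpg 1
dog/0017.jpg 1
...

labels.txt:
cat
dog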

3. Loading an ImageNet-like dataset in PyTorch

Here is the loading code:

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from torch.utils.data import Dataset
from PIL import Image


class FolderDatasetme(Dataset):
    def __init__(self, root, label_file, transform):
        self.root = root
        with open(os.path.join(root, label_file), encoding='utf-8') as f:  # the list files were written as utf-8
            self.info_list = f.readlines()
        self.transform = transform

    def __getitem__(self, item):
        img_path, label = self.info_list[item].rsplit(maxsplit=1)  # split on the last space so paths containing spaces still parse
        img = Image.open(os.path.join(self.root, img_path).replace("\\","/"))
        img = self.transform(img)
        label = torch.tensor(int(label), dtype=torch.int64)
        return img, label

    def __len__(self):
        return self.info_list.__len__()


def get_dataloaders(path,bs,num_workers=0):
    """
    The dataset should be organized as follows:
    data/mydataset/
        |-- class 1
            |-- 0001.jpg
            |-- 0002.jpg
            |-- ...
        |-- class 2
            |-- 0001.jpg
            |-- 0002.jpg
            |-- ...
        train_list.txt
        val_list.txt
        test_list.txt
    @param path: root directory of the ImageNet-like dataset
    @param bs: batch size
    @param num_workers: number of worker processes used by the DataLoaders
    @return: train, val and test DataLoaders
    """
    mu, st = 0, 1  # normalization mean and std; (0, 1) leaves the ToTensor output unchanged
    test_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((44, 44)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(mu,), std=(st,))
    ])
    train_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((44, 44)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(mu,), std=(st,))
    ])

    train = FolderDatasetme(root=path, label_file="train_list.txt", transform=train_transform)
    val = FolderDatasetme(root=path, label_file="val_list.txt", transform=test_transform)
    test = FolderDatasetme(root=path, label_file="test_list.txt", transform=test_transform)

    trainloader = DataLoader(train, batch_size=bs, shuffle=True, num_workers=num_workers)
    valloader = DataLoader(val, batch_size=bs, shuffle=False, num_workers=num_workers)  # evaluation data does not need shuffling
    testloader = DataLoader(test, batch_size=bs, shuffle=False, num_workers=num_workers)

    return trainloader, valloader, testloader


if __name__ == "__main__":
    train_loader, val_loader, test_loader = get_dataloaders(path=r"datasetme/", bs=4, num_workers=0)
    data = next(iter(train_loader))
    images, targets = data  # bs, ten, 1, h, w
    images = images.view(-1, 1, images.size(-2), images.size(-1))
    from visual.visual_tensor import show_tensor
    show_tensor(tensor=images,nrow=4,save_path="train_loader.jpg")
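
Note that visual.visual_tensor is a helper module from the author's own project. If it is not available, torchvision's built-in save_image gives a similar quick sanity check; this is just a minimal sketch of that substitution:

from torchvision.utils import save_image

# write the first batch to disk as a 4-column image grid
save_image(images, "train_loader.jpg", nrow=4)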


Reposted from blog.csdn.net/qq_40243750/article/details/129902170