PyTorch source code walkthrough: torch.utils.data

torch.utils.data

Reading a training set in PyTorch relies on the torch.utils.data module. The module contains 13 members; the two used most often are:

  1. class torch.utils.data.Dataset
  2. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

1. class torch.utils.data.Dataset 

An abstract class representing a Dataset.

# An abstract class used to represent a dataset

All other datasets should subclass it. All subclasses should override __len__, that provides the size of the dataset, and __getitem__, supporting integer indexing in range from 0 to len(self) exclusive.

# All other datasets should subclass this class and must override __len__ and __getitem__:

__len__ returns the size of the dataset;

__getitem__ returns the sample at a given index (an integer in the range 0 to len(self) - 1).
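A minimal sketch of such a subclass (a hypothetical ListDataset that wraps an in-memory Python list; a more realistic file-based example follows in the next section):

import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    # wraps a Python list of (sample, label) pairs
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        # size of the dataset
        return len(self.pairs)

    def __getitem__(self, index):
        # integer indexing in range [0, len(self))
        return self.pairs[index]

ds = ListDataset([(torch.randn(3), i) for i in range(5)])
print(len(ds))   # 5
print(ds[0])     # (a tensor of 3 random values, 0)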

2. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.

The data loader wraps a dataset together with a sampling strategy for drawing data from it (a short toy example follows the parameter list below).

Parameters:
  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch.
  • pin_memory (bool, optional) – If True, the data loader will copy tensors into CUDA pinned memory before returning them.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
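Before the full example below, here is a quick toy illustration of batch_size, shuffle and drop_last using the built-in TensorDataset (the sizes are made up):

import torch
from torch.utils.data import TensorDataset, DataLoader

# 10 samples with 3 features each, plus 10 integer labels
features = torch.randn(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size=4 with drop_last=True drops the final incomplete batch of 2 samples,
# so the loop below yields 2 batches; shuffle=True reshuffles the order every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
for x, y in loader:
    print(x.shape, y)   # torch.Size([4, 3]) and a LongTensor of 4 labels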

Example

First, define a new class ImageList that subclasses Dataset as required:

import torch.utils.data as data

from PIL import Image
import os
import os.path

def default_loader(path):
    # load an image from disk and convert it to single-channel grayscale ('L' mode)
    img = Image.open(path).convert('L')
    return img

def default_list_reader(fileList):
    # parse a text file with one "<image path> <label>" pair per line (whitespace-separated)
    imgList = []
    with open(fileList, 'r') as file:
        for line in file.readlines():
            imgPath, label = line.strip().split()
            imgList.append((imgPath, int(label)))
    return imgList

class ImageList(data.Dataset):
    def __init__(self, root, fileList, transform=None, list_reader=default_list_reader, loader=default_loader):
        self.root      = root
        self.imgList   = list_reader(fileList)
        self.transform = transform
        self.loader    = loader

    def __getitem__(self, index):
        # load the image for this index and apply the optional transform
        imgPath, target = self.imgList[index]
        img = self.loader(os.path.join(self.root, imgPath))

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgList)

Then use ImageList to build the actual dataset and pass it to a DataLoader; ImageList(...) constructs the specific Dataset instance:

# load images (root_path, train_list, batch_size and workers are assumed to be defined elsewhere)
import torch
import torchvision.transforms as transforms

train_loader = torch.utils.data.DataLoader(
    ImageList(root=root_path, fileList=train_list,
              transform=transforms.Compose([
                  transforms.Grayscale(),
                  transforms.RandomCrop(128),
                  transforms.RandomHorizontalFlip(),
                  transforms.ToTensor(),
                  # transforms.Normalize([255.0], [0])
              ])),
    batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=True)
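Finally, train_loader can be iterated batch by batch, typically inside a training loop. A minimal sketch (num_epochs is assumed to be defined; the forward/backward pass is omitted):

for epoch in range(num_epochs):
    for i, (imgs, targets) in enumerate(train_loader):
        # imgs:    FloatTensor of shape [batch_size, 1, 128, 128] (grayscale 128x128 crops)
        # targets: LongTensor of shape [batch_size]
        imgs = imgs.cuda(non_blocking=True)        # pin_memory=True speeds up this host-to-GPU copy
        targets = targets.cuda(non_blocking=True)
        # ... forward pass, loss, backward pass, optimizer step ...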

Reposted from blog.csdn.net/alfred_torres/article/details/82835539