PyTorch中的Dataset、Dataloader和_DataloaderIter

版权声明:本文为博主原创文章,转载请注明: blog.csdn.net/gdymind https://blog.csdn.net/gdymind/article/details/82226509

Dataset

Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset,所有的数据集都应该继承这个类,并override以下两项:
- __len__:代表样本数量。len(obj)等价于obj.__len__()
- __getitem__:返回一条数据或一个样本。obj[index]等价于obj.__getitem__。建议将节奏的图片等高负载的操作放到这里,因为多进程时会并行调用这个函数,这样做可以加速。

dataset中应尽量只包含只读对象,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。如下面例子中的self.num可能在多进程下出问题:

class BadDataset(Dataset):
    def __init__(self):
        self.datas = range(100)
        self.num = 0 # read data times
    def __getitem__(self, index):
        self.num += 1
        return self.datas[index]

Dataloader

官方documentation

Dataset负责表示数据集,它可以每次使用__getitem__返回一个样本。而torch.utils.data.Dataloader提供了对batch的处理,如shuffle等。Dataset被封装在了Dataloader中。

Dataloader的构造函数如下:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

部分参数解释如下:
- num_workers:使用的子进程数,0为不使用多进程。
- - worker_init_fn: 默认为None,如果不是None,这个函数将被每个子进程以子进程id([0, num_workers - 1]之间的数)调用
- sample:采样策略,若这个参数有定义,则shuffle必须为False
- pin_memory:是否将tensor数据复制到CUDA pinned memory中,pin memory中的数据转到GPU中会快一些
- drop_last:当dataset中的数据数量不能整除batch size时,是否把最后 len(dataset) mod batch_size 个数据丢掉
- collate_fn:把一组samples打包成一个mini-batch的函数。可以自定义这个函数以处理损坏数据的情况(先在__getitem__函数中将这样的数据返回None,然后再在collate_fn中处理,如丢掉损坏数据or再从数据集里随机挑一张),但最好还是确保dataset里所有数据都能用。

另外,Dataloader是个iterable,可以进行相关迭代操作。

DataLoaderIter

DatasetDataloaderDataLoaderIter是层层封装的关系,最终在内部使用DataLoaderIter进行迭代。

猜你喜欢

转载自blog.csdn.net/gdymind/article/details/82226509