Dataset
Pytorch中数据集被抽象为一个抽象类torch.utils.data.Dataset
,所有的数据集都应该继承这个类,并override以下两项:
- __len__
:代表样本数量。len(obj)
等价于obj.__len__()
。
- __getitem__
:返回一条数据或一个样本。obj[index]
等价于obj.__getitem__
。建议将节奏的图片等高负载的操作放到这里,因为多进程时会并行调用这个函数,这样做可以加速。
dataset中应尽量只包含只读对象,避免修改任何可变对象。因为如果使用多进程,可变对象要加锁,但后面讲到的dataloader的设计使其难以加锁。如下面例子中的self.num
可能在多进程下出问题:
class BadDataset(Dataset):
def __init__(self):
self.datas = range(100)
self.num = 0 # read data times
def __getitem__(self, index):
self.num += 1
return self.datas[index]
Dataloader
Dataset
负责表示数据集,它可以每次使用__getitem__
返回一个样本。而torch.utils.data.Dataloader
提供了对batch的处理,如shuffle等。Dataset
被封装在了Dataloader
中。
Dataloader
的构造函数如下:
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
部分参数解释如下:
- num_workers
:使用的子进程数,0
为不使用多进程。
- - worker_init_fn
: 默认为None
,如果不是None
,这个函数将被每个子进程以子进程id([0, num_workers - 1]之间的数)调用
- sample
:采样策略,若这个参数有定义,则shuffle
必须为False
- pin_memory
:是否将tensor数据复制到CUDA pinned memory中,pin memory中的数据转到GPU中会快一些
- drop_last
:当dataset中的数据数量不能整除batch size时,是否把最后
个数据丢掉
- collate_fn
:把一组samples打包成一个mini-batch的函数。可以自定义这个函数以处理损坏数据的情况(先在__getitem__
函数中将这样的数据返回None
,然后再在collate_fn
中处理,如丢掉损坏数据or再从数据集里随机挑一张),但最好还是确保dataset里所有数据都能用。
另外,Dataloader
是个iterable,可以进行相关迭代操作。
DataLoaderIter
Dataset
、Dataloader
和DataLoaderIter
是层层封装的关系,最终在内部使用DataLoaderIter
进行迭代。