PyTorch中DataLoader与Dataset的使用、关系&原理

DataLoader(torch.utils.data.DataLoader)

功能:构建可迭代的数据装载器

类中的几个主要变量定义功能介绍如下,除此之外还有11个参数
dataset:Dataset类,决定数据从哪里读取以及如何读取
batchsize:批处理的大小
num_works:是否多进程读取数据
shuffle:每个epoch是否乱序 (true or false)
drop_last:当样本数不能被batchsize整除时,是否丢弃最后一批数据。(true or false)

Dataset(torch.utils.data.Dataset)

功能:Dataset是一个抽象类,所有自定义的Dataset都需要继承这个父类,并且重载其中的__getitem__()函数。
getitem:输入是索引,输出是样本(加标签),也就是定义了索引到样本的映射规则。

代码示例

以图像的分类任务为例(代码中具体是人民币图像识别1元与100元,进行面值二分类任务),重点观察dataset与dataloader如何完成数据读取的任务。
首先主函数中有如下定义

train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

这里的RMBDataset就是继承了Dataset的子类了,有了RMBDataset,就可以将其传递到Dataloader中,构建可迭代的数据装载器了。最后只需要用for循环迭代train_loader或者valid_loader,这样每一次的迭代返回的就是一批训练数据及对应的标签。代码简要省略版示例如下

for i, data in enumerate(train_loader):
	inputs, labels = data
	outputs = neural_net_work(inputs)
	loss = criterion(outputs, labels)
	loss.backward()

而如何将原始的数据集处理成一批批的训练数据与其标签?如何实现这样的迭代规则?这些内容都在Dataset中进行定义,后文会说明。首先看for循环,for循环会找到dataloader类中的next函数来得到这轮要拿数据:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

这里的batch就是一轮迭代最终返回的数据,无视掉它,主要关注两个地方,首先是indices = next(self.sample_iter)这句话是如何获取index的,这句话会调取sampler中的iter函数如下,这里的sampler就是一个采样器

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

可见这个sampler的作用就是告诉我们,每一个batch应该读取哪些数据。这里通过了self.sampler来获取了index,self.sampler就不再继续拓展,其作用就是对样本的index顺序进行随机打乱之后形成一个列表,以供迭代性读取。
回到上面的next函数,上面说了主要关注两个地方,那么剩下第二个地方就是batch = self.collate_fn([self.dataset[i] for i in indices])这句话了。这句话就实现了具体的数据读取
这里就正式调用了dataset,对dataset输入index索引,来返回数据。那么索引是如何通过dataset返回数据呢,这句话的实现就需要通过一开始定义的RMBDataset(该类继承了Dataset)这个类中的getitem函数了,函数内容如下

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = PIL.Image.open(path_img).convert('RGB')  # 0~255

        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor或者其他形式的数据

        return img, label

这个函数就是Dataset中最核心的部分,函数第一句中的data_info是已经通过类内其他函数提取好的数据集信息,包含了图片的路径信息以及标签信息。经过getitem函数之后返回的就是想要的输入数据以及标签了。
回到batch = self.collate_fn([self.dataset[i] for i in indices])这句话,这句话中的collate_fn的功能其实就是一个数据整理器,将得到的呈列表形式的16个数据整理成一个batch的形式,在这个batch(这里的batch是list的形式)里面有两个元素,一个是关于16张输入图片整合后的tensor,一个是标签的tensor。

总结

Dataloader读取的是sampler输出的index索引,读哪些数据是sampler决定的。(读哪些)
Dataloader一般会在定义dataset时要求输入数据在硬盘中的存储路径data_dir。(在哪读)
通过Dataset的getitem,通过Dataloader调用的sampler给出的索引读取数据。(怎么读)

发布了33 篇原创文章 · 获赞 45 · 访问量 2505

猜你喜欢

转载自blog.csdn.net/nstarLDS/article/details/104673127