Pytorch中DataLoader的使用方法

在Pytorch中,torch.utils.data中的Dataset与DataLoader是处理数据集的两个函数,用来处理加载数据集。通常情况下,使用的关键在于构建dataset类。

一:dataset类构建。

在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。

class dataset:
    def __init__(self,...):
        ...
    def __len__(self,...):
        return n
    def __getitem__(self,item):
        return data[item]

正常情况下,该数据集是要继承Pytorch中Dataset类的,但实际操作中,即使不继承,数据集类构建后仍可以用Dataloader()加载的。

在dataset类中,__len__(self)返回数据集中数据个数,__getitem__(self,item)表示每次返回第item条数据。

二:DataLoader使用

在构建dataset类后,即可使用DataLoader加载。DataLoader中常用参数如下:

1.dataset:需要载入的数据集,如前面构造的dataset类。

2.batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个batch进行训练。

3.shuffle:是否在打乱数据集样本顺序。True为打乱,False反之。

4.drop_last:是否舍去最后一个batch的数据(很多情况下数据总数N与batch size不整除,导致最后一个batch不为batch size)。True为舍去,False反之。

三:举例

兔兔以指标为1,数据个数为100的数据为例。

import torch
from torch.utils.data import DataLoader

class dataset:
    def __init__(self):
        self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
        self.y=(torch.sin(self.x)+1)/2
    def __len__(self):
        return 100
    def __getitem__(self, item):
        return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
    print(batch)

当然,利用这个数据集可以进行简单的神经网络训练。

from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
                 nn.Sigmoid(),
                 nn.Linear(5,1),
                 nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
    print('the {} epoch'.format(epoch))
    for batch in data:
        yp=bp(batch[0])
        loss=Loss(yp,batch[1])
        optim.zero_grad()
        loss.backward()
        optim.step()

猜你喜欢

转载自blog.csdn.net/weixin_60737527/article/details/126754254