Pytorch Dataset和DataLoader 加载训练数据

Dataset 基类

torch.utils.data.Dataset 为数据集的基类, 继承这个基类,我们能够非常快速的实现对数据的加载。

我们要实现自己加载数据的类,并继承于Dataset 这个类,重载类的成员函数
1、__1en__方法, 能够实现通过全局的len()方法获取其中的元素个数;
2、getitem 方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第 i条数据。

from torch.utils.data import Dataset, DataLoader
# 完成数据集类
class MyDataset(Dataset):
    def __init__(self):
    def __getitem__(self, index):
        """ 必须实现,作用是:获取索引对应位置的一条数据 :param index: :return: """
        return to_tensor(self.data[index])
    def __len__(self):
        """ 必须实现,作用是得到数据集的大小 :return: """
        return len(self.data)
    def to_tensor(data):
        return torch.from_numpy(data)

使用Dataset 能够进行数据的读取,但是还需要如下实现:

批处理数据(Batching the data)
打乱数据(Shuffling the data)
使用多线程multiprocessing并行加载数据

定义好 Dataset 之后就可以用DataLoader进行加载。

DataLoader 调用一句话即可,dataset 指向 自定义的读取数据类。

data_loader = DataLoader(dataset=data, batch_size=2, shuffle=True, num_workers=2)

参数:
1、dataset:提前定义的dataset的实例;
2、batch_size:传入数据的batch大小,常常是32、64
3、shuffle:bool类型,打乱数据;
4、num_workers:加载数据的线程数。
5、drop_last:bool类型,为真,表示最后的数据不足一个batch,就删掉

迭代遍历:

   for step, (batch_x, batch_y) in enumerate(data_loader):
        print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

猜你喜欢

转载自blog.csdn.net/long630576366/article/details/124863780