pytorch DataLoader 自定义数据集

pytorch 提供了一种数据处理的方式,能够生成mini-batch的数据,在训练和测试的时候进行多线程处理,加快准备数据的速度。这个函数工具是

torch.utils.data import Dataset, DataLoader

其中Dataset是我们定义自己的多线程数据处理框架的父类,我们定义的框架要继承这个类
下面简单定义数据准备的框架吧!!!

from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
    def __init__(self, filepath, transform=None,keys = None, target_transform=None):
        pass
    '''
    首先说明一下以上的初始化参数,filepath是数据集的路径,transform是对源数据(features)的一些变化,target_transform是对目标数据(labels)的一些变换,keys是键,因为我的数据是这样的,整体是字典格式的,每个键对应的值又是ndarray数据,所以我通过键来索引对应的值 
    def __getitem__(self,index):
        padd
    def __len__(self):
        pass
    '''

接下来结合数据来解释以下代码:

class MyDataset(Dataset):
    def __init__(self, filepath, transform=None,keys = None, target_transform=None):
        with open(filepath,'rb') as f:
            self.data = pickle.load(f)
        self.keys = keys
        self.input_seq = self.data[self.keys[0]]  ### 输入序列
        self.output_seq = self.data[self.keys[1]]   #### 输出序列
        self.transform = transform #### 对输入序列进行变换
        self.target_transform = target_transform   ###### 对输出序列进行变换

    def __getitem__(self, index):
        input_seq,output_seq = self.input_seq[index],self.output_seq[index]  ## 按照索引迭代读取内容
        if self.transform is not None:
            input_seq = self.transform(input_seq)
            output_seq = self.transform(output_seq)
        return input_seq,output_seq  ### 直接输出输入序列和输出序列

    def __len__(self):
        return self.data[self.keys[0]].shape[0]   ### 返回的是样本集的大小,样本的个数

train_data = MyDataset(filepath = 'train30.pickle',keys = ['aa','bb'])
test_data = MyDataset(filepath = 'test30.pickle',keys = ['aa','bb'])
train_loader = DataLoader(dataset = train_data,batch_size = 32,shuffle = False)
test_loader = DataLoader(dataset = test_data,batch_size = 32,shuffle = False)

可以通过调用train_loader和 test_loader 来调取mini-batch的数据,batch_size在train_loader和test_loader中值都已经设好了,通过以下代码批量调取:

from torch.autograd import Variable
for i,(input_seq,out_seq) in enumerate(train_loader):
    input_seq = Variable(input_seq.cuda())
    output_seq = Variable(output_seq.cuda())

猜你喜欢

转载自blog.csdn.net/baidu_36161077/article/details/81062980
今日推荐