pytorch 提供了一种数据处理的方式,能够生成mini-batch的数据,在训练和测试的时候进行多线程处理,加快准备数据的速度。这个函数工具是
torch.utils.data import Dataset, DataLoader
其中Dataset是我们定义自己的多线程数据处理框架的父类,我们定义的框架要继承这个类
下面简单定义数据准备的框架吧!!!
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self, filepath, transform=None,keys = None, target_transform=None):
pass
'''
首先说明一下以上的初始化参数,filepath是数据集的路径,transform是对源数据(features)的一些变化,target_transform是对目标数据(labels)的一些变换,keys是键,因为我的数据是这样的,整体是字典格式的,每个键对应的值又是ndarray数据,所以我通过键来索引对应的值
def __getitem__(self,index):
padd
def __len__(self):
pass
'''
接下来结合数据来解释以下代码:
class MyDataset(Dataset):
def __init__(self, filepath, transform=None,keys = None, target_transform=None):
with open(filepath,'rb') as f:
self.data = pickle.load(f)
self.keys = keys
self.input_seq = self.data[self.keys[0]] ### 输入序列
self.output_seq = self.data[self.keys[1]] #### 输出序列
self.transform = transform #### 对输入序列进行变换
self.target_transform = target_transform ###### 对输出序列进行变换
def __getitem__(self, index):
input_seq,output_seq = self.input_seq[index],self.output_seq[index] ## 按照索引迭代读取内容
if self.transform is not None:
input_seq = self.transform(input_seq)
output_seq = self.transform(output_seq)
return input_seq,output_seq ### 直接输出输入序列和输出序列
def __len__(self):
return self.data[self.keys[0]].shape[0] ### 返回的是样本集的大小,样本的个数
train_data = MyDataset(filepath = 'train30.pickle',keys = ['aa','bb'])
test_data = MyDataset(filepath = 'test30.pickle',keys = ['aa','bb'])
train_loader = DataLoader(dataset = train_data,batch_size = 32,shuffle = False)
test_loader = DataLoader(dataset = test_data,batch_size = 32,shuffle = False)
可以通过调用train_loader和 test_loader 来调取mini-batch的数据,batch_size在train_loader和test_loader中值都已经设好了,通过以下代码批量调取:
from torch.autograd import Variable
for i,(input_seq,out_seq) in enumerate(train_loader):
input_seq = Variable(input_seq.cuda())
output_seq = Variable(output_seq.cuda())