from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Map-style dataset that serves consecutive values from a generator.

    Every ``__getitem__`` call returns the next value produced by an
    internal generator (``value``, ``value + 1``, ...). The requested index
    is ignored, so results depend only on the order of the calls.
    """

    def __init__(self, n_episode, value):
        """Store the settings; the generator itself is created later.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode
        # Filled in by set_iter(), or lazily on the first __getitem__ call.
        self.iterator = None

    def set_iter(self):
        """Create (or restart) the generator that backs ``__getitem__``."""
        self.iterator = self._iter()

    def _iter(self):
        """Yield ``value``, ``value + 1``, ``value + 2``, ... without end."""
        i = 0
        while True:
            yield self.value + i
            i += 1

    def __getitem__(self, i):
        """Return the next generated value; the index ``i`` is ignored."""
        if self.iterator is None:
            # set_iter() has not been called yet: start the generator on
            # demand instead of failing with AttributeError.
            self.set_iter()
        return next(self.iterator)

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: build a 5-episode dataset starting at 1, start its generator,
    # then read every index and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    dataset.set_iter()
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
1
2
3
4
5
Version 1:想着在 `__getitem__` 里面调用生成器函数 `_iter()`,每次 getitem 就取一次值,结果发现拿到的是一个 generator 对象(而不是里面的值),没法直接用。
from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Version 1 of the demo dataset: ``__getitem__`` returns a generator.

    This version intentionally shows the first pitfall: calling the
    generator function only creates a generator object, it does not
    produce a value.
    """

    def __init__(self, n_episode, value):
        """Store the settings.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode

    def _iter(self):
        """Yield ``value``, ``value + 1``, ``value + 2``, ... without end."""
        i = 0
        while True:
            yield self.value + i
            i += 1

    def __getitem__(self, i):
        """Return a brand-new generator object; the index ``i`` is ignored."""
        # No next() here: the caller receives the generator itself, which is
        # exactly the behaviour this version demonstrates.
        v = self._iter()
        return v

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: read every index of a 5-episode dataset and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
Version 2:查了一下之后,发现应该用 `next()` 才能取到迭代器的值,于是加了 `next()`,现在能拿到值了,但是每次都只能取到第一个值。为啥?因为把迭代器的初始化操作 `v = self._iter()` 放在了 `__getitem__` 里面,那么每次 getitem 实际上都会重新创建一个迭代器,所以永远从头开始。
from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Version 2 of the demo dataset: every access returns the first value.

    This version intentionally shows the second pitfall: a new generator is
    created on every ``__getitem__`` call, so ``next()`` always yields the
    generator's first value.
    """

    def __init__(self, n_episode, value):
        """Store the settings.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode

    def _iter(self):
        """Yield ``value``, ``value + 1``, ``value + 2``, ... without end."""
        i = 0
        while True:
            yield self.value + i
            i += 1

    def __getitem__(self, i):
        """Return the first generated value; the index ``i`` is ignored."""
        # A fresh generator is built on each call, so progress is never kept.
        v = self._iter()
        return next(v)  # use next() to pull a value out of the generator

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: read every index of a 5-episode dataset and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
1
1
1
1
1
Version 3:既然这样,那就把迭代器的初始化放到 `__init__` 的时候去做,然后发现果然 work。
from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Version 3 of the demo dataset: the generator is built in ``__init__``.

    The generator function is defined as a closure inside ``__init__`` and
    instantiated once, so successive ``__getitem__`` calls advance the same
    generator. The requested index is ignored.
    """

    def __init__(self, n_episode, value):
        """Store the settings and create the generator once.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode

        def _iter():
            # Closure over ``self``: yields value, value + 1, ... without end.
            i = 0
            while True:
                yield self.value + i
                i += 1

        self.iterator = _iter()

    def __getitem__(self, i):
        """Return the next generated value; the index ``i`` is ignored."""
        return next(self.iterator)

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: read every index of a 5-episode dataset and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
1
2
3
4
5
Version 4:我想着实际代码里面肯定不能这么写,因为初始化的时候很多函数都在里面,所以就加了一个 `set_iter()` 方法,专门用来初始化迭代器。
from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Version 4 of the demo dataset: a dedicated method starts the generator.

    ``set_iter()`` creates the generator; each ``__getitem__`` call then
    returns its next value (``value``, ``value + 1``, ...). The requested
    index is ignored.
    """

    def __init__(self, n_episode, value):
        """Store the settings; the generator itself is created later.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode
        # Filled in by set_iter(), or lazily on the first __getitem__ call.
        self.iterator = None

    # Initialise the iterator.
    def set_iter(self):
        """Create (or restart) the generator that backs ``__getitem__``."""
        self.iterator = self._iter()

    def _iter(self):
        """Yield ``value``, ``value + 1``, ``value + 2``, ... without end."""
        i = 0
        while True:
            yield self.value + i
            i += 1

    def __getitem__(self, i):
        """Return the next generated value; the index ``i`` is ignored."""
        if self.iterator is None:
            # set_iter() has not been called yet: start the generator on
            # demand instead of failing with AttributeError.
            self.set_iter()
        return next(self.iterator)

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: build a 5-episode dataset starting at 1, start its generator,
    # then read every index and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    dataset.set_iter()
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
1
2
3
4
5
Version 5:好像直接在 `__init__` 里面初始化 iterator 也可以……??事实证明没毛病。
from torch.utils.data import Dataset
class MetaDataset(Dataset):
    """Version 5 of the demo dataset: the generator is created in ``__init__``.

    Each ``__getitem__`` call returns the next value of a single generator
    (``value``, ``value + 1``, ...). The requested index is ignored.
    """

    def __init__(self, n_episode, value):
        """Store the settings and create the generator once.

        Args:
            n_episode: Length reported by ``len(dataset)``.
            value: First value the generator yields.
        """
        self.value = value
        self.n_episode = n_episode
        self.iterator = self._iter()  # initialise the iterator directly in __init__

    def _iter(self):
        """Yield ``value``, ``value + 1``, ``value + 2``, ... without end."""
        i = 0
        while True:
            yield self.value + i
            i += 1

    def __getitem__(self, i):
        """Return the next generated value; the index ``i`` is ignored."""
        return next(self.iterator)

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode
if __name__ == "__main__":
    # Demo: read every index of a 5-episode dataset and print what comes back.
    dataset = MetaDataset(5, 1)
    print("dataset length:", len(dataset))
    # Index access is kept on purpose: the demo shows what dataset[i] returns.
    for i in range(len(dataset)):
        print(dataset[i])
输出:
dataset length: 5
1
2
3
4
5
Version 6:
这种方式不管是enumerate还是data_loader[i]都可以拿到那个iterate_dataset()生成的元素!
本质上就是dataset类可以视为list,既能enumerate访问也能dataset[index]访问
dataset类只需要改写__getitem__和__len__方法,TensorDataset是这两个方法
IterableDataset 则只需要实现 `__iter__`(不需要 `__getitem__`)
但是 PyTorch 的 DataLoader 类只是一个可迭代对象,只能用 for / enumerate 遍历,不能用 dataloader[index] 访问(我才发现这一点)。DataLoader 只实现了 `__iter__`(以及 `__len__`),没有实现 `__getitem__`。