PyTorch 迭代器读取数据

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Dataset that serves consecutive values from an endless generator.

    Every ``__getitem__`` call returns the next value of the sequence
    ``value, value + 1, value + 2, ...``; the index that is passed in is
    ignored. ``n_episode`` only determines what ``len()`` reports.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value produced by the generator.
    """

    def __init__(self, n_episode, value):
        self.value = value
        self.n_episode = n_episode
        # No generator yet: it is created by set_iter() or on first access.
        self.iterator = None

    def set_iter(self):
        """Create (or restart) the generator that ``__getitem__`` reads from."""
        self.iterator = self._iter()

    def _iter(self):
        """Yield ``self.value``, ``self.value + 1``, ... without end."""
        offset = 0
        while True:
            yield self.value + offset
            offset += 1

    def __getitem__(self, i):
        """Return the next generated value; ``i`` is not used."""
        # Fall back to a fresh generator when set_iter() was never called,
        # instead of failing with AttributeError on the missing attribute.
        if getattr(self, "iterator", None) is None:
            self.set_iter()
        return next(self.iterator)

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

if __name__ == "__main__":
    # Demo: report the length, start the generator, then read every index.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    dataset.set_iter()
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
1
2
3
4
5

Version 1:想着在__getitem__里面调用生成器函数_iter(),每次__getitem__就取一次值,结果发现拿到的是一个generator对象(生成器本身),而不是它产生的值,没法直接用。

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Version 1: ``__getitem__`` hands back the generator object itself.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value the generator would produce.
    """

    def __init__(self, n_episode, value):
        self.n_episode = n_episode
        self.value = value

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

    def __getitem__(self, i):
        """Return a brand-new generator; ``i`` is not used."""
        # The generator is returned as-is: no value is pulled out of it.
        return self._iter()

    def _iter(self):
        """Yield ``self.value``, ``self.value + 1``, ... without end."""
        offset = 0
        while True:
            yield self.value + offset
            offset += 1

if __name__ == "__main__":
    # Demo: report the length, then print whatever each index returns.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>
<generator object MetaDataset._iter at 0x7effac314d60>

在这里插入图片描述
Version 2:查了一下之后,发现应该用next()才能取到迭代器的值,于是加了next(),现在能拿到值了,但是每次都只能取到第一个值。为啥?因为把迭代器的初始化操作v = self._iter()放在了__getitem__里面,那么每次调用__getitem__实际上都会重新创建一个迭代器,从头开始。

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Version 2: a new generator is built on every ``__getitem__`` call.

    Because the generator is recreated each time, every access returns its
    first value, i.e. ``value``.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value produced by each generator.
    """

    def __init__(self, n_episode, value):
        self.n_episode = n_episode
        self.value = value

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

    def __getitem__(self, i):
        """Return the first value of a freshly created generator; ``i`` is unused."""
        fresh = self._iter()
        first = next(fresh)  # next() pulls one value out of the generator
        return first

    def _iter(self):
        """Yield ``self.value``, ``self.value + 1``, ... without end."""
        offset = 0
        while True:
            yield self.value + offset
            offset += 1

if __name__ == "__main__":
    # Demo: report the length, then print the item at every index.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
1
1
1
1
1

在这里插入图片描述
Version 3:既然这样,那就把迭代器的初始化放到__init__的时候去做,然后发现果然work

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Version 3: the generator is defined and started inside ``__init__``.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value produced by the generator.
    """

    def __init__(self, n_episode, value):
        self.n_episode = n_episode
        self.value = value

        def _iter():
            # Closure over ``self``: ``self.value`` is read on every step.
            step = 0
            while True:
                yield self.value + step
                step += 1

        # One generator per instance, shared by all __getitem__ calls.
        self.iterator = _iter()

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

    def __getitem__(self, i):
        """Return the next generated value; ``i`` is not used."""
        return next(self.iterator)

if __name__ == "__main__":
    # Demo: report the length, then print the item at every index.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
1
2
3
4
5

在这里插入图片描述
Version 4:我想着实际代码里面肯定不能这么写,因为__init__里面通常还有很多别的初始化逻辑,所以就加了一个set_iter()方法,专门用来初始化迭代器(注意:必须先调用set_iter(),再用下标访问)。

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Version 4: the generator is started by a dedicated ``set_iter`` method.

    ``self.iterator`` is assigned only in ``set_iter``, so that method has to
    be called before the dataset is indexed.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value produced by the generator.
    """

    def __init__(self, n_episode, value):
        self.n_episode = n_episode
        self.value = value

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

    def __getitem__(self, i):
        """Return the next generated value; ``i`` is not used."""
        return next(self.iterator)

    def set_iter(self):
        """Create (or restart) the generator that ``__getitem__`` reads from."""
        self.iterator = self._iter()

    def _iter(self):
        """Yield ``self.value``, ``self.value + 1``, ... without end."""
        offset = 0
        while True:
            yield self.value + offset
            offset += 1

if __name__ == "__main__":
    # Demo: report the length, start the generator, then read every index.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    dataset.set_iter()
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
1
2
3
4
5

在这里插入图片描述
Version 5:好像直接在__init__里面初始化iterator也可以……??事实证明没毛病。

from torch.utils.data import Dataset

class MetaDataset(Dataset):
    """Version 5: the generator method is started directly in ``__init__``.

    Args:
        n_episode: Length reported by ``len(dataset)``.
        value: First value produced by the generator.
    """

    def __init__(self, n_episode, value):
        self.n_episode = n_episode
        self.value = value
        # Start the generator right away so __getitem__ can use it at once.
        self.iterator = self._iter()

    def __len__(self):
        """Return the configured number of episodes."""
        return self.n_episode

    def __getitem__(self, i):
        """Return the next generated value; ``i`` is not used."""
        return next(self.iterator)

    def _iter(self):
        """Yield ``self.value``, ``self.value + 1``, ... without end."""
        offset = 0
        while True:
            yield self.value + offset
            offset += 1

if __name__ == "__main__":
    # Demo: report the length, then print the item at every index.
    dataset = MetaDataset(5, 1)
    n_items = len(dataset)
    print("dataset length:", n_items)
    for idx in range(n_items):
        print(dataset[idx])

输出:

dataset length: 5
1
2
3
4
5

在这里插入图片描述

Version 6:

这种方式不管是enumerate还是data_loader[i]都可以拿到那个iterate_dataset()生成的元素!
在这里插入图片描述

本质上就是dataset类可以视为list,既能enumerate访问也能dataset[index]访问

map-style的Dataset子类只需要实现__getitem__和__len__两个方法,TensorDataset实现的就是这两个方法
IterableDataset则只需要实现__iter__(不需要__getitem__)

但是PyTorch的DataLoader只是一个可迭代对象,只能用for/enumerate遍历,不能用dataloader[index]访问(我才发现这一点)。DataLoader只实现了__iter__和__len__,没有实现__getitem__

在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_31347869/article/details/125777859