从零学习pytorch 第3课 DataLoader类运行过程

课程目录(在更新,喜欢加个关注点个赞呗):
从零学习pytorch 第1课 搭建一个超简单的网络
从零学习pytorch 第1.5课 训练集、验证集和测试集的作用
从零学习pytorch 第2课 Dataset类
从零学习pytorch 第3课 DataLoader类运行过程
从零学习pytorch 第4课 初见transforms
从零学习pytorch 第5课 PyTorch模型搭建三要素
从零学习pytorch 第5.5课 Resnet34为例学习nn.Sequential和模型定义
从零学习PyTorch 第6课 权值初始化
从零学习PyTorch 第7课 模型Finetune与预训练模型
从零学习PyTorch 第8课 PyTorch优化器基类Optimier
上一课中讲解了,如何构建Dataset子类,也就是MyDataset。MyDataset中,主要是获取图片的索引以及标签,但是触发Dataset去读取图片以及标签却是在DataLoader中实现的。咱们这一课一步一步来,看图片如何从硬盘,流到模型中的。


  1. train_data = MyDataset(txt_path = train_txt_path,…)
  2. train_loader = DataLoader(dataset=train_data,…)
  3. for i, data in enumerate(train_loader,0)
  4. class DataLoader():def __iter__(self): return _DataLoaderIter(self)
  5. _DataLoaderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i] for i in indices])
  6. class MyDataset(): def __getitem__():img = Image:open(fn).convert(‘RGB’)
  7. class MyDataset(): img = self.transform(img)
  8. inputs,labels = data; inputs,labels = Variable(inputs),Variable(labels)
  9. output = net(inputs)

这个光用代码还是很难讲清楚,我再好好解释一下。
总之就是从MyDataset中来,再回到MyDataset中去
一开始通过MyDataset,创建一个有txt路径,有读取图片的方法的函数
然后就是pytorch自己规范化的流程,在第六步才会调用MyDataset中的__getitem__()函数,通过Image.open(读取图片)
然后对数据进行预处理,也就是transform,然后将数据转换成Variable类型,就成为模型的输入


还是没懂?没事我们再来看一遍

  1. MyDataset类初始化一个实例,txt中有图片路径和标签
  2. 初始化DataLoader的时候,把train_data传入,从而使DataLoader拥有图片的路径
  3. 循环的每一次循环我们叫iteration,每一个iteration读取一个batch的图片数据,这里的data是一个batch的图片数据和标签,是一个list
  4. 在class DataLoader中再调用class _DataLoaderIter()
  5. 在_DataLoaderItem()类中会执行__next__(self)函数,这个函数会通过indeces=next(self.sample_iter)获取到一个batch的indices后再通过,batch = self.collate_fn([self.dataset[i] for i in indices]) 获取到数据。
    这里batch = self.collate_fn([self.dataset[i] for i in indices]) 函数会调用到self.collate_fn函数
  6. 这个self.collate_fn中会调用到MyDataset类中的__getitem__()函数,在__getitem__()中通过Image.open(fn).convert(‘RGB’)读取图片
  7. 通过Image.open(fn).convert(‘RGB’)读取图片之后,会对图片进行预处理,就是transform,返回的还是img和label,再通过self.collate_fn来拼接成一个batch。一个batch是一个list,有两个元素,第一个元素是图片数据,是一个4D的Tensor变量,shape为(64,3,32,32),第二个元素是标签,shape为64。64说明一个batch中有64个图片
  8. 最后一步就是将图片数据转化成Variable类型,然后才是模型的真正输入

希望大家都理解了这个过程。

最后从个人角度,再补充几点。

  • __iter__(self)什么时候会调用的,就是在被循环的时候。如果循环了,每一个iteration都会调用一次__iter__(self)。这也是第三步和第四步连接起来的关键。

整个过程的难点就在于一个循环的过程,前两步主要是展示两个类的输入参数是什么。第三步我们通过对DataLoader的循环,来获取一个又一个的batch。循环中,我们自动调用了__iter__(self),而__iter__(self)返回了个_DataLoaderIter,而_DataLoaderIter中自动执行了next函数,得到了indeces,然后执行了__next__(self),通过indeces得到了batch。关键在得到batch的这一步,我们得到batch的那个函数collate_fn自动调用了Dataset类中的__getitem__

说白了就是循环的自动调用机制。不懂了可以提问!

通过里哦阿姐图片从硬盘到模型的过程,我们可以更好地对数据进行处理。

发布了78 篇原创文章 · 获赞 14 · 访问量 9736

猜你喜欢

转载自blog.csdn.net/qq_34107425/article/details/104100870
今日推荐