从零学习pytorch 第2课 Dataset类

这一次主要讲解pytorch读取数据的机制和流程,然后按照流程编写代码

Dataset基类

PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。Dataset
类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
看一下源码:
在这里插入图片描述
这里有一个getitem函数,getitem函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

list的制作方法通常是将图片的路径标签信息存储在一个txt中,然后从txt中读取,所以总结一下基本流程:

  1. 制作存储了图片路径和标签信息的txt
  2. 将这些信息转化成list,list的每一个元素对应一个样本
  3. 通过getitem函数,读取数据和标签。

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再将)。

总而言之,要让PyTorch读取自己的数据集,只要两步

  1. 制作图片数据的索引
  2. 构建Dataset子类

制作图片数据索引

非常简单,就是一些基本的操作,百度一下“”python如何保存txt文件“”就可以知道了。
然后一般来说,txt都是这样的格式
./Data/train/01.png 0
./Data/train/02.png 0
./Data/train/03.png 1
./Data/train/04.png 1

构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Datset):
    def __init__(self,txt_path,transform=None,target_transform=None):
        fh = open(txt_path,'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0].int(words[1])))
        self.imgs = imgs
        self.transform = transform
    def __getitem__(self,index):
        fn,label = self.imgs[index]
        img=Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    def __len__(self):
        return len(sefl.imgs)

Init

  • 初始化中,我们从已经准备好的txt中获取了图片的路径和表亲啊,并且春初在self.imgs这意味着self.imgs是一个list就像上面我们讲的那样

  • 初始化中,初始化了transform,transform是一个Compose类型,transform中包含一个list,list中定义了各种对图像进行的操作,比如随机剪裁,旋转反转等。

  • 一个图片都进来之后,会经过数据处理(数据增强),最终变成另外一张图片,也就是模型的输入数据。但是PyTorch的数据增强是将原始图片进行处理,是不会生成新的图片。因此假如我们使用randomcrop这样的随机操作的时候,每次epoch输入进来的图片不会是一摸一样的,达到样本多样性的功能

getitem

  • self.imgs是一个list,每一个元素都是一个二元tuple,这很好理解(str1,str2)这样的
  • 利用Image.open对图片进行读取,img类型为Image,mode=‘RGB’
  • 用transform对图片进行处理,里面可能有什么 标准化(减均值除以标准差),随机剪裁什么的(后面会细说)

这样Mydataset就构建好了,剩下的操作就交给DataLoader,在DataLoader中,会触发Mydataset中的getitem函数读取一张图片的数据和标签,并将多个图片拼接成一个batch返回,每一个batch才是模型真正的输入。

下一章节会讲解DataLoader是如何获取一个batch的

发布了47 篇原创文章 · 获赞 4 · 访问量 2255

猜你喜欢

转载自blog.csdn.net/qq_34107425/article/details/104097402
今日推荐