Pytorch学习第一讲:数据加载

Pytorch官网上给出了一个关于加载CIFAR10数据集的例子:

主要使用了torchvision数据包,里面有一些ImageNet,CIFAR-10和MNIST等常见数据集。加载数据分成三个部分:

  • torchvision.transforms(数据预处理)
  • torchvision.datasets(数据集读取)
  • torchvision.DataLoader(数据集加载)

代码实现如下:

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

但是例子中给出的形式是分类数据,即数据是图片,标签是一个字符的这种形式。而我在项目中需要的是标签也是同样的和图片维度的图片,所以,我就自己定义了一个数据加载器,来加载我的训练数据:

import torch.utils.data as data
import os
import os.path
import glob
import torchvision.transforms as transforms
from PIL import Image


def make_dataset(root, train=True):
    dataset = []
    if train:
        dirgt = os.path.join(root, 'data\\label')
        dirimg = os.path.join(root, 'data\\image')
        for fGT in glob.glob(os.path.join(dirgt, '*.jpg')):
            fName = os.path.basename(fGT)
            dataset.append([os.path.join(dirimg, fName), os.path.join(dirgt, fName)])
    return dataset



class mytraindata(data.Dataset):

    def __init__(self, root, transform=None, train=True, rescale=None):
        self.train = train
        self.transform = transform
        self.rescale = rescale
        if self.train:
            self.train_set_path = make_dataset(root, train)

    def __getitem__(self, idx):
        if self.train:
            img_path, label_path = self.train_set_path[idx]
            img = Image.open(img_path)
            if self.rescale:
                img = img.resize((224, 224))
            if self.transform:
                transform = transforms.ToTensor()
                img = transform(img)

            label = Image.open(label_path)
            if self.rescale:
                label = label.resize((224, 224))
            if self.transform:
                transform = transforms.ToTensor()
                label = transform(label)
            return img, label

    def __len__(self):
        return len(self.train_set_path)

主要来看class mytraindata这个类,这个类继承了data.Dataset这个父类,里面主要有三个函数:

  • __init__(self): 初始化函数
  • __getitem__(self, idx):一个个数据去读取的函数
  • __len__(self):返回这个训练集的长度

主要看getitem的返回项,我这里返回image,和label,所以就根据你需要的训练数据集的格式来编写这个数据加载器就可以了,另外,数据需要的一些预处理也都是在getitem里面实现了,比如缩放到224*224, 转换成tensor的格式等。

在调用这个类来加载数据的时候(在主函数里调用),只需要写:

dataset = mytraindata(".", transform=True, train=True, rescale=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
for epoch in range(100000):
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
inputs 和labels就是你需要的训练的数据和标签了。


猜你喜欢

转载自blog.csdn.net/vivianyzw/article/details/80969191