PyTorch基础(四)-----数据加载和预处理

前言

之前已经简单讲述了PyTorch的Tensor、Autograd、torch.nn和torch.optim包,通过这些我们已经可以简单的搭建一个网络模型,但这是不够的,我们还需要大量的数据,众所周知,数据是深度学习的灵魂,深度学习的模型是由数据“喂”出来的,这篇我们来讲述一下数据的加载和预处理。

  • 首先,我们要引入torch包
import torch
torch.__version__

一、数据的加载

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。

1.1 Dataset

Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。自定义的Dataset类需要继承它并且实现2个成员方法:

  • 1.__getitem__():该方法定义用索引(0-len(self))获取一条数据或一个样本
  • 2.__len__():该方法返回数据集的总长度

下面我们使用Kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,为了方便介绍,我们使用里面的数据字典来做说明

  • 首先,我们需要引用相关的包
from torch.utils.data import Dataset
import pandas as pd
  • 自定义一个数据集
#定义一个数据集
class BulldozerDataset(Dataset):
    """ 数据集演示 """
    def __init__(self, csv_file):
        """实现初始化方法,在初始化的时候将数据读载入"""
        self.df=pd.read_csv(csv_file)
    def __len__(self):
        '''
        返回df的长度
        '''
        return len(self.df)
    def __getitem__(self, idx):
        '''
        根据 idx 返回一行数据
        '''
        return self.df.iloc[idx].SalePrice
  • 至此,我们的数据集已经定义完成了,我们可以实例化一个对象来访问
ds_demo= BulldozerDataset('median_benchmark.csv')
  • 我们可以直接使用如下命令查看数据集数据
# 前面我们已经实现了__len__方法,所以可以直接使用
len(ds_demo)
  • 使用索引可以直接访问对应的数据
ds_demo[0]

自定义的数据集已经创建好了,下面我们使用官方提供的数据载入器,读取数据

1.2 DataLoader

DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、shuffle(是否进行shuffle操作)、num_workers(加载数据时使用几个子进程)。下面做一个简单的演示:

dl = torch.utils.data.DataLoader(ds_demo,batch_size = 10,shuffle = True,num_workers = 0)

DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据

idata=iter(dl)
print(next(idata))

常见的用法是使用for循环对其进行遍历

for i, data in enumerate(dl):
    print(i,data)
    # 为了节约空间,这里只循环一遍
    break

至此,我们已经可以通过dataset定义数据集,并使用DataLorder载入和遍历数据集。

二、torchvision包

torchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程中最后的pip install torchvision 就是安装这个包。
torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用。

  • 这里总结一下torchvision已经预装的数据集:
数据集名称
MNIST
COCO
CIFAR-10
ImageNet
Captions
Detection
LSUN
ImageFolder
Imagenet-12
STL10
SVHN
PhotoTour

PyTorch中自带的数据集由2个上层api提供,分别是torchvision和torchtext

  • torchvision提供了对图像数据处理的相关数据和api
    • 数据位置:torchvision.datasets;例如:torchvision.datasets.MNIST
  • torchtext提供了对文本数据处理的相关数据和api
    • 数据位置:torchtext.datasets;例如:torchtext.datasets.IMDB

下面我们做一个简单的演示

  • 首先,我们要引入torchvision包
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录
                                      train=True,  # 表示是否加载数据库的训练集,false的时候加载测试集
                                      download=True, # 表示是否自动下载 MNIST 数据集
                                      transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理

2.1 torchvision.models

torchvision不仅提供了常用的图像数据集,而且还提供了一些训练好的网络模型,可以加载之后直接使用,或者继续进行迁移学习。torchvision.models模块的子模块中包含以下模型:

网络模型
AlexNet
VGG
ResNet
SqueezeNet
DenseNet

我们直接可以使用训练好的模型,当然这个与datasets相同,都是需要从服务器下载的。

  • 首先,我们需要导入torchvision.models
import torchvision.models as models
  • 直接使用
resnet18 = models.resnet18(pretrained=True)

2.2 torchvision.tranforms

transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强

  • 首先,我们需要引入torchvision.tranforms,然后做一个简单的演示
from torchvision import transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  #先四周填充0,在把图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])

肯定有人会问:(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010) 这几个数字是什么意思?
官方的这个帖子有详细的说明: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21 这些都是根据ImageNet训练的归一化参数,可以直接使用,我们认为这个是固定值就可以。
到这里,我们已经完成了PyTorch的基本内容介绍。

参考文献

https://github.com/zergtant/pytorch-handbook/blob/master/chapter2

猜你喜欢

转载自blog.csdn.net/dongjinkun/article/details/113869697