PyTorch torchvision包

torchvision包服务于PyTorch框架,包括了计算机视觉中一些流行的数据集网络模型以及常见的图片变换方法,主要由以下几部分构成:
torchvision.datasets: 一些加载数据的函数和常用的数据集接口
torchvision.models:包含常用的模型结构(含预训练模型)
torchvision.transforms:常见的图片变换,如裁剪、旋转等
torchvision.utils:其它一些有用的方法

利用torchvision.datasets接口可以得到许多种类的数据集,可以传给DataLoader做进一步处理,使用这些数据集的API都差不多,以使用MNIST数据集为例,来看一下使用的流程

导入需要的包或模块:

import torch
import torchvision.transforms as transforms
import torchvision.datasets as dataset

MNIST的文档描述:

root-就是存放数据的文件路径
train-如果为True,那么得到训练集,否则得到测试集
download-True表示需要从网上下载这个数据集,如果已经下载了,那么直接加载
transform-对图片进行一些处理
示例:

mnist_train_data = datasets.MNIST('../MNIST', train=True, download=True, transform = 
                                tranforms.Compose([
                                    torch.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))
                                 ]))

对图片的tranforms有很多种,根据需要选择,他们可以利用torchvision.tranforms.Compose(transforms)把这些操作串在一起,就像上面例子一样,形式是:

transforms.Compose([
    transforms.CenterCrop(10),
    transforms.ToTensor(),
])

更多的细节以后遇到再加以补充。。。

参考:

  1. https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.5_fashion-mnist
  2. PyTorch官方文档

猜你喜欢

转载自www.cnblogs.com/patrolli/p/11870242.html
今日推荐