Pytorch教程(十六):FashionMNIST数据集DataSet DataLoader

torchvision.datasets

由于MNIST数据集太简单,简单的网络就可以达到99%以上的top one准确率,也就是说在这个数据集上表现较好的网络,在别的任务上表现不一定好。因此zalando research的工作人员建立了fashion mnist数据集,该数据集由衣服、鞋子等服饰组成,包含70000张图像,其中60000张训练图像加10000张测试图像,图像大小为28x28,单通道,共分10个类,如下图,每3行表示一个类。
在这里插入图片描述
所以我们通过torchvison来处理FashionMNIST数据集:

import torch
import torchvision
import torchvision.transforms as transforms

train_set = torchvision.datasets.FashionMNIST(
    root = './data/FasionMNIST',  # 将数据保存在本地什么位置
    train=True,  # 我们希望数据用于训练集,其中6万张图片用作训练数据,1万张图片用于测试数据
    download=True,  # 如果目录下没有文件,则自动下载
    transform=transforms.Compose([
        transforms.ToTensor()
    ])  # 我们将数据转为Tensor类型
)

这样我们就完成了FashionMNIST数据的提取和转换。

如果这个过程中报错:ImportError: IProgress not found. Please update jupyter and ipywidgets.。一般是jupyter的版本有些低了,可能是你默认的环境,所以重装以下就好了:

# 可以先用你的环境 conda activate xx
# 卸载jupyter:
pip install --upgrade jupyter

访问单独某个训练数据:
在这里插入图片描述
在这里插入图片描述

torchvision.dataloader

dataloader使我们能够访问数据并提供查询功能。

train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

在这里插入图片描述
通过train_loader方式得到的batch包含图像的张量是4维的张量,形状是[10, 1, 28, 28],这告诉我们有10个图像,他们都有1个单独的颜色通道,高度宽度都是28;对于包含标签的张量,他的长度是10,每10个图像为一批数据。

现在让我们看看如何使用torchvison.utils.make_grid函数一次性的画出整批图像:
在这里插入图片描述
我们可以看到,我们已经使用torchvision.utils.make_grid函数创建了一个网络,我们把图像张量作为第一个参数,nrow=10这样我们所有的图像就会沿着一行显示,nrow参数指定每一行的图像数量,因为我们的batch_size=10,这就给我们了一排图像,我们使用np.transpose(grid, (1,2,0)),这样轴就满足了图像的功能需要的规格。

现在我们知道了datasetdataloader之间如何交互的了。现在试试如何批量处理数据:
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/vincent_duan/article/details/120754248
今日推荐