torchvision.datasets
由于MNIST数据集太简单,简单的网络就可以达到99%以上的top one准确率,也就是说在这个数据集上表现较好的网络,在别的任务上表现不一定好。因此zalando research的工作人员建立了fashion mnist数据集,该数据集由衣服、鞋子等服饰组成,包含70000张图像,其中60000张训练图像加10000张测试图像,图像大小为28x28,单通道,共分10个类,如下图,每3行表示一个类。
所以我们通过torchvison
来处理FashionMNIST
数据集:
import torch
import torchvision
import torchvision.transforms as transforms
train_set = torchvision.datasets.FashionMNIST(
root = './data/FasionMNIST', # 将数据保存在本地什么位置
train=True, # 我们希望数据用于训练集,其中6万张图片用作训练数据,1万张图片用于测试数据
download=True, # 如果目录下没有文件,则自动下载
transform=transforms.Compose([
transforms.ToTensor()
]) # 我们将数据转为Tensor类型
)
这样我们就完成了FashionMNIST数据的提取和转换。
如果这个过程中报错:ImportError: IProgress not found. Please update jupyter and ipywidgets.
。一般是jupyter的版本有些低了,可能是你默认的环境,所以重装以下就好了:
# 可以先用你的环境 conda activate xx
# 卸载jupyter:
pip install --upgrade jupyter
访问单独某个训练数据:
torchvision.dataloader
dataloader使我们能够访问数据并提供查询功能。
train_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
通过train_loader
方式得到的batch包含图像的张量是4维的张量,形状是[10, 1, 28, 28],这告诉我们有10个图像,他们都有1个单独的颜色通道,高度宽度都是28;对于包含标签的张量,他的长度是10,每10个图像为一批数据。
现在让我们看看如何使用torchvison.utils.make_grid
函数一次性的画出整批图像:
我们可以看到,我们已经使用torchvision.utils.make_grid
函数创建了一个网络,我们把图像张量作为第一个参数,nrow=10
这样我们所有的图像就会沿着一行显示,nrow
参数指定每一行的图像数量,因为我们的batch_size=10
,这就给我们了一排图像,我们使用np.transpose(grid, (1,2,0))
,这样轴就满足了图像的功能需要的规格。
现在我们知道了dataset
和dataloader
之间如何交互的了。现在试试如何批量处理数据: