torchvision模块
不要担心自己的形象,只关心如何实现目标。——《原则》,生活原则
英文PyTorch文档的Libraries中有关于torchvision的介绍
实战讲解有以下几个部分
数据预处理部分
—数据增强
—数据预处理
网络模块设置
—加载预训练模型
> 模型训练、保存与测试(下一个文章里)
—选择性的保存
—读取模型并测试
在这个(torchvision)模块中有许多我们需要的功能,比如:torchvision.datasets(封装了一些常用的数据集 和 定义了数据存放的方法) , torchvision.models(包括了一些经典网络架构的实现,以及与训练模型) , torchvision.transfrom(数据预处理模块)
纯分类任务的数据类型可以参考ImageFolder模块。官网链接
当然不同的任务,数据集的构建可能是不一样的~~~
1 数据预处理
首先我们导入所需要的pip包,并定义好我们想用的数据集的路径
然后我们对图像进行 数据增强处理–ToTensor–Normalize
接下来我们构建数据集
目前我们的数据集里面全是1、2、3这种,我们可以读取json文件来把编号和名字对应起来
数据展示
- 注意:我们这里如果想要展示数据的话,必须将已经转换为temsor格式的数据还原为numpy格式,并且需要对标准化进行还原
- 其中transpose操作是把 C * H * W 还原为 H * W * C
- squeeze的作用是降维。具体可参考别人的博客
2 网络模型设置