文章
对于pytorch数据集的使用,示例代码如下:
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torchvision import transforms
import torchvision
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transform = Compose([transforms.ToTensor()])
# 关于官方数据集的使用还是关键要看pytorch的官方文档
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True)
# 查看测试数据集中的第一个数据
# print(test_set[0])
# 查看测试数据集中的分类情况
# print(test_set.classes)
#
# 取出第一个数据中的图片(img)和分类结果(target)
# img,target = test_set[0]
# 查看图片数据的类型
# print(img)
# print(target)
# 输出类别
# print(test_set.classes[target])
# 查看图片
# img.show()
# 使用tensorboard显示tensor数据类型的图片
writer = SummaryWriter("logs")
for i in range(10):
# 取出数据中的图片(img)和分类结果(target)
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
上述代码运行结果在tensorboard可视化:
代码train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
中常用参数讲解
root:根目录,存放数据集的位置
train:若为True,则划分为训练数据集,若为False,则划分为测试数据集
transform:指定输入数据集处理方式
download:若为True,则会将数据集下载到root指定的目录下,否则不会下载
官方文档对参数的解释:
-
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
-
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
-
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
-
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
-
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
扫描二维码关注公众号,回复: 14322896 查看本文章
注意
- 关于官方数据集的使用还是关键要看pytorch的官方文档
- 下载数据集的细节之处:知道下载链接(下载链接可以在源码中查看)之后可以不用使用代码下载了,使用迅雷来下载可能会更快。
- 要学会使用Pycharm中的
ctrl+p
和ctrl+alt
这两个快捷键 - pytorch官网
- pytorch官方数据集(下载数据集方法)