PyTorch——Dataloader使用

一、Dataloader是啥

前面我在写PyTorch的第一篇文章里讲过Dataset是啥,Dataset就是将数据集分类,并且分析出这些数据集它的位置哪、大小多少、这个数据集一共有多少数据......等等信息

那么把Dataset比作一副扑克牌,那么如果你就让这副牌放在桌子那不去取牌,那你怎么打牌?Dataloader就是做【取牌】这个操作,就是去【读取数据】

二、使用DataLoader

首先先看一下官方文档对于DataLoader是怎么使用的:torch.utils.data — PyTorch 2.4 documentation

其中框住的解释的是常用的参数变量的作用解释

用一些例子结合tensorboard,直观地生动地解释一下

batch_size参数】:一次读取几个数据

drop_last参数】:最后一次读取,剩余数据不足【batch_size】时,要不要舍去

shuffle参数】:当多轮读取的时候,图片顺序是否一样,False是顺序一样

代码编写:导包(torchvision为了dataset,DataLoader则来自torch.utils.data)

然后先用dataset把数据集获取到,这里我用的是下载好的pytorch内置数据集CIFAR10,你们也可以用自定义数据集,注意语法区别就行

然后用DataLoader,设置好参数配置

import torchvision
from torch.utils.data import DataLoader

# 用dataset获取pytorch的内置数据集(我已经下载好,而且选用测试数据集)
test_dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor())

# 然后用DataLoader读取,并设置好参数(上面例子里没讲到的参数,你就当默认这么写就好了,我也不知道)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

三、结合transforms、tensorboard

语法都是之前学过的,直接创建SummaryWriter( )对象,指定图像文件生成在哪个文件夹;

然后遍历整个DataLoader返回的数据,返回的是一个列表;

每次循环,提取出每个元素里的【img】跟【target】,【img】就是tensorboard的【.add_images()】所需要的图像,另外step跟着遍历递增就行

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 用dataset获取pytorch的内置数据集(我已经下载好,而且选用测试数据集)
test_dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor())

# 然后用DataLoader读取,并设置好参数(上面例子里没讲到的参数,你就当默认这么写就好了,我也不知道)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

step = 0
write = SummaryWriter("DataLoader_logs")
for item in test_loader:
    img, target = item
    # print(img.shape)
    # print(target)

    # 利用tensorboard生成图像
    # 一定一定要注意!!是.add_images不是.add_image!不能漏了s
    write.add_images("dataloader", img, step)
    step += 1

write.close()

下一篇讲神经网络

猜你喜欢

转载自blog.csdn.net/m0_73991249/article/details/141380543