DataLoader是PyTorch中的一种数据类型,在PyTorch架构中训练或者验证模型经常要使用它,那么怎么生成以及使用这样的数据类型?
一、参数设置
torch.utils.data.DataLoader(
dataset #数据加载
batch_size = 1 #批处理样本大小
shuffle = False #是否在每一轮epoch打乱样本顺序
sampler = None #指定数据加载中使用的索引/键的序列
batch_sampler = None #和sampler类似
num_workers = 0 #是否进行多进程加载数据设置
collate_fn = None #是否合并样本列表以形成一小批Tensor
pin_memory = False #如果True,数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last = False #True若数据集大小不能被batch_size整除,则删除最后一个不完整的批处理。
timeout = 0 #如果为正,则为从工作人员收集批处理的超时值
worker_init_fn = None )
具体可参考官方文档。
1、dataset:(数据类型 Dataset)
输入的数据类型,也是最重要的参数,它表示要加载数据的数据集对象。
2、batch_size:(数据类型 int)
批处理样本的大小,默认为1。
3、shuffle:(数据类型 bool)
在每轮迭代训练时是否将数据洗牌。默认设置为False。设置为True则是在每一轮中,输入数据的顺序将被打乱,这是为了使数据更有独立性,训练的时候一般都设置为True,若输入数据是有序的,就不要设置成True了。
4、collate_fn:(数据类型 callable可调用对象)
将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。
5、sampler:(数据类型 Sampler)
采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。
6、num_workers:(数据类型 Int)
子进程数量,默认是0。使用多少个子进程来加载数据。0 就是使用主进程来加载数据。注意:这个数字必须是大于等于0的,该值的设置应该量内存大小而为。
7、pin_memory:(数据类型 bool)
内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
8、drop_last:(数据类型 bool)
丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
9、timeout:(数据类型 numeric)
超时值,默认为0。是用来设置数据读取的超时时间,超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
二、实际应用
import torch
from torch.utils.data import Dataset, DataLoader
#---------------预处理-----------------
transform = transforms.Compose([
transforms.Resize((224, 224), 2),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
#--------------数据加载----------------
trainset = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=False,
transform=transform)
# torch.utils.data.DataLoader
trainloader = DataLoader(dataset=trainset,
batch_size=32,
shuffle=True,
num_workers=0)
for epoch in range(100):
running_loss = 0.0
batch_size = 32
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)