【PyTorch】Torchvision

3、Torchvision

PyTorch official website: https://pytorch.org

1、Dataset

Dataset description: https://www.cs.toronto.edu/~kriz/cifar.html

Dataset usage documentation:

CIFAR10 dataset: https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10

Parameters:

  • root: directory where the dataset is stored
  • train: True (training set), False (test set)
  • transform: transform applied to each image (e.g. PIL Image → Tensor)
  • target_transform: transform applied to the target (label)
  • download: whether to download the dataset to root if it is not already there

Basic usage:

import torchvision

train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)

print(test_set[0])
print(test_set.classes)

img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()

Output:

Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F0220>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F00D0>
3
cat
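
Besides indexing individual samples, the dataset object also reports its size and the class-to-index mapping. A small inspection sketch, continuing from the train_set and test_set created above:

print(len(train_set))         # 50000 training images
print(len(test_set))          # 10000 test images
print(test_set.class_to_idx)  # {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, ...}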

Converting to Tensor and displaying with TensorBoard:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=False, download=True)

writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()
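
After running the script, the logged images can be viewed by starting TensorBoard from the project directory with tensorboard --logdir=logs and opening the address it prints (http://localhost:6006 by default) in a browser.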

2、DataLoader

Documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

Parameters:

  • batch_size: how many samples to load per batch (default: 1)
  • shuffle: True to reshuffle the data at every epoch (default: False)
  • num_workers: how many subprocesses to use for data loading (default: 0, i.e. loading happens in the main process); see the note after this list
  • drop_last: whether to drop the last incomplete batch when the dataset size is not divisible by batch_size (default: False)
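
A practical note on num_workers: when num_workers > 0, the DataLoader spawns worker subprocesses, and on Windows the iteration code must be guarded by if __name__ == "__main__":, otherwise the workers re-import the script and fail. A minimal sketch, assuming the CIFAR-10 data has already been downloaded to ../data as in the previous section:

import torchvision
from torch.utils.data import DataLoader

if __name__ == "__main__":  # guard required on Windows when num_workers > 0
    test_data = torchvision.datasets.CIFAR10("../data", train=False,
                                             transform=torchvision.transforms.ToTensor())
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True,
                             num_workers=2, drop_last=False)
    imgs, targets = next(iter(test_loader))
    print(imgs.shape)  # torch.Size([64, 3, 32, 32])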

2.1 test_data

import torchvision
from torch.utils.data import DataLoader

# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())

# First image of the test set and its target
img, target = test_data[0]
print(img.shape)
print(target)

Output:

torch.Size([3, 32, 32])  # 3 channels, 32 x 32
3
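
As a side note, ToTensor converts the PIL image into a float32 tensor laid out as [C, H, W] with values scaled to [0, 1]. A quick check, continuing from the img obtained above:

print(img.dtype)             # torch.float32
print(img.min(), img.max())  # values lie in the range [0.0, 1.0]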

2.2 test_loader

import torchvision
from torch.utils.data import DataLoader

# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# First image of the test set and its target
# img, target = test_data[0]
# print(img.shape)
# print(target)

# test_loader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)

Output:

torch.Size([4, 3, 32, 32])  # 4 images, 3 channels, 32 x 32
tensor([1, 2, 0, 8])        # the targets of the 4 images batched together
...
...

Note: the targets [1, 2, 0, 8] are not sampled in dataset order; because shuffle=True, the sampling order is random.
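
If a random but reproducible order is wanted, a seeded torch.Generator can be passed to the DataLoader. A minimal sketch, assuming the test_data object defined above:

import torch
from torch.utils.data import DataLoader

# With a fixed seed, the shuffled order is identical on every run
g = torch.Generator()
g.manual_seed(42)
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True,
                         num_workers=0, drop_last=False, generator=g)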

2.3 drop_last

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# batch_size=64
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step += 1

writer.close()

Note: the last batch contains only 16 images; this is because drop_last=False and the 10000 test images do not divide evenly by batch_size=64.

When the last batch cannot be filled to the full batch_size, you can either keep the real remainder (drop_last=False) or discard it outright (drop_last=True).

When drop_last=True is set, that last incomplete batch is discarded.
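
The effect can also be checked numerically from the loader length. A small sketch, assuming the same 10000-image test_data as above (10000 = 156 × 64 + 16):

from torch.utils.data import DataLoader

loader_keep = DataLoader(test_data, batch_size=64, drop_last=False)
loader_drop = DataLoader(test_data, batch_size=64, drop_last=True)
print(len(loader_keep))  # 157 -> 156 full batches plus one batch of 16
print(len(loader_drop))  # 156 -> the incomplete batch of 16 is dropped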

2.4 shuffle

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)

# shuffle=False
writer = SummaryWriter("logs")

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1

writer.close()

Note: both epochs sample the data in exactly the same order; if you want the data reshuffled each epoch, set shuffle=True.
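
A quick way to see the difference without TensorBoard is to compare the first batch's targets across epochs. A minimal sketch, again assuming the test_data from above:

from torch.utils.data import DataLoader

loader = DataLoader(test_data, batch_size=64, shuffle=True, drop_last=True)
for epoch in range(2):
    imgs, targets = next(iter(loader))  # first batch of this epoch
    print(epoch, targets[:8])           # differs across epochs because shuffle=True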

Reprinted from blog.csdn.net/m0_70885101/article/details/127897302