pytorch torchvision study

torchvision 模块

torchvision是独立于pytorch的关于图像操作的工具库,主要包含了如下4个子模块或包:

  • datasets
  • utils
  • transforms
  • models

1、datasets

torchvision.datasets包含如下数据集,可以下载和加载

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10
  • SVHN 
  • PhotoTour
from torchvision import datasets
​​​​​​​train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transforms.ToTensor(),
                               download=True)

此操作便可下载MNIST的训练数据集,

数据集有 API: - __getitem__ - __len__ 他们都是 torch.utils.data.Dataset的子类。因此, 他们可以使用torch.utils.data.DataLoader里的多线程 (python multithreading) 。

例如:

torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)

2、utils

utils主要提供了两个方法:

  • make_grid  
  • save_image
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 
将输入的minbatch_size图片转换成一张大的网格图片

torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False) 
将输入的图片保存,如果输入的是minbatch_size图片,先make_grid转换成大的网格图再保存

3、transforms

了方便进行数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:

PIL.Image/numpy.ndarray与Tensor的相互转化;

归一化;

对PIL.Image进行裁剪、缩放等操作。

通常,在使用torchvision.transforms,我们通常使用transforms.Compose将transforms组合在一起。

transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
                          std = [ 0.229, 0.224, 0.225 ]),
])
  • transforms.ToTensor() :把shape=(H x W x C)的像素值范围为[0, 255]的PIL.Image或者numpy.ndarray转换成shape=(C x H x W)的像素值范围为[0.0, 1.0]的torch.FloatTensor。
  • transforms.Normalize(mean,std) : 此转换类作用于torch.tensor,给定均值(R, G, B)和标准差(R, G, B),用公式channel = (channel - mean) / std进行规范化。

4、models

torchvision.models包含下列常用网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
import torchvision
model = torchvision.models.resnet50(pretrained=True)

这样就导入了resnet50的预训练模型了,

如果只需要网络结构,不需要用预训练模型的参数来初始化

model = torchvision.models.resnet50(pretrained=False)

如果要导入densenet模型也是同样的道理,比如导入densenet169,且不需要是预训练的模型,

model = torchvision.models.densenet169(pretrained=False)

由于预训练参数默认是假,所以等价于

model = torchvision.models.densenet169()


 

猜你喜欢

转载自blog.csdn.net/qq_42527487/article/details/84943369