PyTorch (11): The Computer Vision Toolkit torchvision

Computer vision is one of the most important application areas of deep learning. To make work in this area easier for researchers, PyTorch provides a dedicated vision toolkit, torchvision.

torchvision can be installed with pip: pip install torchvision

torchvision consists of three main parts:

  • models: provides the classic network architectures used in deep learning, together with pretrained weights, including AlexNet, the VGG series, the ResNet series, and the Inception series.
  • datasets: provides loaders for common datasets such as MNIST, CIFAR-10/100, ImageNet, and COCO, which can download the data where supported; every dataset class subclasses torch.utils.data.Dataset.
  • transforms: provides common data preprocessing operations, working on PIL Image objects and Tensors.

Loading a pretrained model

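from torchvision import models
from torch import nn

# Load a pretrained model; the weights are downloaded if not already present.
# Older versions cache them under ~/.torch/models/, recent ones under ~/.cache/torch/hub/checkpoints/
# (torchvision >= 0.13 prefers weights=models.ResNet34_Weights.DEFAULT over pretrained=True)
resnet34 = models.resnet34(pretrained=True, num_classes=1000)

# Replace the final fully connected layer for a 10-class problem (the default is ImageNet's 1000 classes)
resnet34.fc = nn.Linear(512, 10)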
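
import torch as t
from torchvision import transforms as T

# ToPILImage converts a CHW Tensor into a PIL Image;
# float tensors are expected to hold values in [0, 1], so use t.rand rather than t.randn
to_pil = T.ToPILImage()
to_pil(t.rand(3, 128, 128))
>> (output: a 128x128 RGB image of random noise)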

 

Loading a dataset

from torchvision import transforms as T
transform = T.Compose([
    T.ToTensor(),                          # PIL Image -> Tensor with values in [0, 1]
    T.Normalize(mean=[0.5], std=[0.5]),    # (x - 0.5) / 0.5 -> values in [-1, 1]
])

from torchvision import datasets
# Use data/ as the dataset directory; download MNIST if it is not there yet
# train=False selects the test set
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)

len(dataset)
>>10000
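Because every torchvision dataset is a torch.utils.data.Dataset, it only has to implement __len__ (what len(dataset) calls) and __getitem__ (what dataset[i] calls). The sketch below uses a hypothetical RandomImages class that mimics MNIST's shapes with random data, so it runs without any download, and shows how DataLoader stacks individual samples into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomImages(Dataset):
    """Hypothetical stand-in for MNIST: random 1x28x28 images with labels 0-9."""

    def __init__(self, n, transform=None):
        self.images = torch.rand(n, 1, 28, 28)   # pixel values in [0, 1)
        self.labels = torch.randint(0, 10, (n,))
        self.transform = transform

    def __len__(self):
        # len(dataset) calls this
        return len(self.labels)

    def __getitem__(self, index):
        # dataset[index] returns one (image, label) pair
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index]

dataset = RandomImages(100)
print(len(dataset))  # 100

# DataLoader collates 16 individual samples into one batch tensor per field.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
imgs, labels = next(iter(loader))
print(imgs.shape, labels.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])
```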

 

torchvision also provides two commonly used utility functions: make_grid, which tiles multiple images into a single grid image, and save_image, which saves a Tensor as an image file.

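from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from torchvision import transforms as T
to_img = T.ToPILImage()

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

dataiter = iter(dataloader)
imgs, label = next(dataiter)
print(label)
# Undo Normalize(mean=[0.5], std=[0.5]) so pixel values are back in [0, 1],
# then tile the 16 images into a 4x4 grid (4 images per row)
img = make_grid(imgs * 0.5 + 0.5, 4)
to_img(img)

>>tensor([9, 2, 8, 3, 5, 6, 0, 5, 6, 2, 2, 5, 6, 6, 7, 6])
(the labels differ from run to run because shuffle=True)

to_img(imgs[4] * 0.5 + 0.5)  # show a single image from the batch

from PIL import Image
save_image(img, 'a.png')  # write the grid to a PNG file
Image.open('a.png')       # reopen it to check the saved result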


Source: blog.csdn.net/qq_24946843/article/details/89452118