Computer vision is one of the most important application areas of deep learning. To make things easier for researchers, PyTorch provides a dedicated vision toolkit, torchvision.
torchvision can be installed with pip install torchvision.
torchvision mainly consists of the following three parts:
- models: provides a variety of classic deep learning network architectures together with pretrained models, including AlexNet, the VGG series, the ResNet series, and the Inception series.
- datasets: provides commonly used datasets with download support, all designed as subclasses of torch.utils.data.Dataset, including MNIST, CIFAR10/100, ImageNet, and COCO.
- transforms: provides common data preprocessing operations, mainly on Tensor and PIL Image objects.
Load Model
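from torchvision import models
from torch import nn
# Load the pretrained model; the weights are downloaded if they are not already cached
# Older releases cache pretrained weights under ~/.torch/models/ (newer ones use ~/.cache/torch/hub/checkpoints/)
# Note: recent torchvision releases deprecate pretrained=True in favor of the weights argument
resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# Replace the final fully connected layer for a 10-class problem (the default is the 1000 ImageNet classes)
resnet34.fc = nn.Linear(512, 10)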
import torch as t
from torchvision import transforms as T
to_pil = T.ToPILImage()
to_pil(t.randn(3,128,128))
>> The output is shown in the figure below
from torchvision import transforms as T
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
])
from torchvision import datasets
# Use data/ as the dataset root; the dataset is downloaded if it is not already there
# Pass train=False to get the test set
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
len(dataset)
>>10000
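Since the torchvision datasets are built on torch.utils.data.Dataset, the same interface can be sketched with synthetic data. The class below, `RandomDigits`, is a hypothetical stand-in for MNIST (random images and labels, no download). It only illustrates the two methods a map-style dataset needs, `__len__` and `__getitem__`, and how an optional transform is applied per sample.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomDigits(Dataset):
    """Hypothetical stand-in for MNIST: random 28x28 images with labels 0-9."""

    def __init__(self, size=100, transform=None):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(size, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, 10, (size,), generator=generator)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[index])


toy_dataset = RandomDigits(size=100)
print(len(toy_dataset))       # 100
image, label = toy_dataset[0]
print(image.shape)            # torch.Size([1, 28, 28])

# DataLoader stacks individual samples into batches.
toy_loader = DataLoader(toy_dataset, batch_size=16, shuffle=False)
images, labels = next(iter(toy_loader))
print(images.shape, labels.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16])
```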
torchvision also provides two commonly used utility functions: make_grid, which arranges multiple images into a single grid, and save_image, which saves a Tensor as an image file.
from torch.utils.data import DataLoader
from torchvision.utils import make_grid,save_image
from torchvision import transforms as T
to_img = T.ToPILImage()
dataloader = DataLoader(dataset,batch_size=16,shuffle=True)
dataiter = iter(dataloader)
imgs,label = (next(dataiter))
print(label)
img = make_grid(imgs, 4)  # arrange the 16 images into a 4x4 grid (4 images per row)
to_img(img)
>>tensor([9, 2, 8, 3, 5, 6, 0, 5, 6, 2, 2, 5, 6, 6, 7, 6])
to_img(imgs[4])
from PIL import Image
save_image(img,'a.png')
Image.open('a.png')