pytorch(一)pytorch&torchvision介绍与安装

torchvision介绍

torchvision 是PyTorch中专门用来处理图像的库。这个包中有四个大类。

torchvision.datasets
torchvision.models
torchvision.transforms
torchvision.utils

torchvision.datasets

torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了一些图像的公开数据集。

MNISTCOCO
Captions
Detection
LSUN
ImageFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
···

下载数据集方法

train_data = torchvision.datasets.MNIST(
    root="./mnist",  # 设置数据集的根目录
    train=True,  # 是否是训练集
    transform=trans,  # 对数据进行转换
    download=DOWNLOAD_MNIST
)
# 第二个参数是数据分块之后每一个块的大小,第三个参数是是否大乱数据
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(
    root="./mnist",
    train=False,  # 测试集,所以false
    transform=trans,
    download=DOWNLOAD_MNIST
)
test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

torchvision.models

torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

torchvision.models模块的子模块中包含以下模型结构。

AlexNet
VGG
ResNet
SqueezeNet
DenseNet

快速创建模型

# 快速创建一个权重随机初始化的模型
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

# 或使用 pretrained=True 来加载一个别人预训练好的模型
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

torchvision.transforms

torchvision.transforms提供了一般的图像操作类,这个包中包含resize、crop等常见的data augmentation操作,基本上PyTorch中的data augmentation操作都可以通过该接口实现。

多个操作一起使用


# 对MNIST进行处理,初始的MNIST28*28,我们把它处理成96*96的torch.Tensor的格式
from torchvision import transforms as transforms
import torchvision
from torch.utils.data import DataLoader
 
# 图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(96), # 缩放到 96 * 96 大小
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
 
DOWNLOAD = True
BATCH_SIZE = 32
 
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
 
# print(len(train_dataset))
# print(len(train_loader))

torchvision 和 pytorch 安装

conda install pytorch cuda92 -c pytorch
conda install pytorch torchvision cuda92 -c pytorch

猜你喜欢

转载自blog.csdn.net/m0_45117053/article/details/104784116