pytorh学习笔记——cifar10(九)使用torhvision的标准resnet模型

 之前的demo都是模仿和简化了已有的模型,也可以直接调用orhvision的标准模型,代码将更加简单。

新建resnet18.py

import torch.nn as nn
from torchvision import models


class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.model = models.resnet18(pretrained=True)  # 调用torchvision.models中的resnet18
        self.num_ftrs = self.model.fc.in_features  # 获取全连接层的输入特征数
        self.model.fc = nn.Linear(self.num_ftrs, num_classes)  # 修改全连接层

    def forward(self, x):
        out = self.model(x)
        return out


def resnet18():
    return ResNet18()

 在之前的train.py脚本导入模型,并修改脚本中的net定义,改为:
net = resnet18().to(device),

即可运行开始训练,首次运行,会自动下载模型:

下载完之后就开始训练:

猜你喜欢

转载自blog.csdn.net/xulibo5828/article/details/143222540