How to quickly load the official pretrained ResNet models in PyTorch

When building neural networks in PyTorch, ResNet (especially resnet50) is often used as the backbone. A typical network configuration looks like this:

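import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models

class base_resnet(nn.Module):
    def __init__(self):
        super(base_resnet, self).__init__()
        # Build resnet50 and load the ImageNet-pretrained weights (downloaded if not cached)
        self.model = models.resnet50(pretrained=True)
        # Alternative: load the weights from a local file instead (see below)
        #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
        # Global average pooling down to a 1x1 feature map
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # Run resnet50 layer by layer, stopping before the final fc classifier
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)  # shape: (N, 2048, 1, 1)

        # Uncomment to flatten into (N, 2048) feature vectors
        # x = x.view(x.size(0), x.size(1))
        return x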

This network inherits all of resnet50's layers and parameters, but its forward pass has been changed: the data goes through the convolutional layers and the global average pooling and then stops, skipping the final fully connected (linear classification) layer.

The following line is equivalent to calling the resnet50 network defined in torchvision; it automatically downloads and loads the pretrained parameters:

self.model = models.resnet50(pretrained=True)

If it is changed to pretrained=False, the pretrained parameters are not loaded and the parameters are randomly initialized instead. However, when I ran this kind of code on a server, I found that every time I re-ran the program with pretrained=True, it downloaded the resnet50 parameters again. Because the network connection there was sometimes very poor, just downloading the basic resnet50 could take me a long time, so I wanted to download the resnet50 parameters in advance and load them directly when needed. This is indeed possible.
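For context: as far as I know, torchvision fetches these weights through `torch.hub.load_state_dict_from_url`, which saves them under `<torch.hub.get_dir()>/checkpoints` (by default `~/.cache/torch/hub/checkpoints`, relocatable with the `TORCH_HOME` environment variable) and skips the download when a file with the same name is already there. If the download repeats on every run, that cache folder is probably not being kept between runs; this is an assumption about the server setup, not something established above. The short script below, which assumes a PyTorch version recent enough to provide `torch.hub.get_dir()`, prints that folder and lists what is cached in it.

```python
import os
import torch

# Folder where torch.hub (and therefore torchvision) stores downloaded weights.
# Default: ~/.cache/torch/hub/checkpoints; moved by setting $TORCH_HOME.
ckpt_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
print("weights cache:", ckpt_dir)

if os.path.isdir(ckpt_dir):
    for name in sorted(os.listdir(ckpt_dir)):
        print("  cached:", name)
else:
    print("  (no cached weights yet)")
```

Pointing `TORCH_HOME` at a persistent folder, or copying a downloaded `.pth` file into this directory under its original name, should let `pretrained=True` reuse it. The exact file name torchvision expects can differ between torchvision versions, so check the URL your version uses.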

To do this, download the weight file for the model you need from its URL to a local directory. The URLs of the common ResNet models are:

 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
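On an unreliable connection a download can stop halfway, so it is worth checking the file before using it. The hex string in each file name above (for example `19c8e357`) is, as far as I know, the first characters of the file's SHA-256 digest; `torch.hub` reads it with the same regular expression used below when hash checking is enabled. This is a minimal standard-library sketch that fetches a file into `./model` and verifies that prefix; `download` and `sha256_prefix_ok` are hypothetical helper names, not part of torchvision.

```python
import hashlib
import os
import re
import urllib.request

# Pattern for the hash prefix in names like "resnet50-19c8e357.pth" (same as in torch.hub).
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")

def sha256_prefix_ok(path):
    """Return True if the file's SHA-256 hex digest starts with the prefix in its name."""
    match = HASH_REGEX.search(os.path.basename(path))
    if match is None:
        raise ValueError(f"no hash prefix in file name: {path}")
    prefix = match.group(1)
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read 1 MiB at a time
            digest.update(chunk)
    return digest.hexdigest().startswith(prefix)

def download(url, dest_dir="./model"):
    """Download url into dest_dir unless it is already there, then verify the hash prefix."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.isfile(dest):
        tmp = dest + ".part"
        urllib.request.urlretrieve(url, tmp)  # write to a temporary name first
        os.replace(tmp, dest)                 # so a half-finished file never gets the final name
    if not sha256_prefix_ok(dest):
        raise RuntimeError(f"hash mismatch (corrupted download?), delete and retry: {dest}")
    return dest
```

For example, `download('https://download.pytorch.org/models/resnet50-19c8e357.pth')` would save the file as `./model/resnet50-19c8e357.pth`, the path the loading code below expects when run from the project directory.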

Download the file and put it in a folder named model in the same directory as the model file net.py. The following code then avoids re-downloading the model every time:

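# Build resnet50 without downloading anything, then load the local weight file.
# Note: './model/...' is resolved relative to the current working directory.
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))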

 

Source: blog.csdn.net/t20134297/article/details/103885879