Pytorch学习笔记2--模型定义

版权声明:本样板的所有内容,包括文字、图片,均为原创。如有问题可以邮箱[email protected] https://blog.csdn.net/qq_29893385/article/details/84644478

保存模型的推荐方法

这主要有两种方法序列化和恢复模型。

第一种(推荐只保存和加载模型参数

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二种保存和加载整个模型:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

然而,在这种情况下,序列化的数据被绑定到特定的类和固定的目录结构,所以当在其他项目中使用时,或者在一些严重的重构器之后它可能会以各种方式break。

上面是官方文档给出的解释,可能容易理解但是落实到代码层面可能有点难以下手,所以下面给出一个具体的实例,展示一下如何使用torch.save()函数保存模型或模型参数,torch.load()加载模型.

import torch
from torch.autograd import  Variable
import torch.nn.functional as  F
import matplotlib.pyplot as plt

x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())

x,y = Variable(x),Variable(y)

def save():

    net = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1),

    )

    optimizer = torch.optim.SGD(net.parameters(),lr=0.5)
    loss_fun = torch.nn.MSELoss()

    for t in range(100):
        out = net(x)
        loss = loss_fun(out,y)


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net,'net.pkl')
    torch.save(net.state_dict(),'net_params.pkl')

    plt.figure(1,figsize=(10,3))
    plt.subplot(131)
    plt.title('Net')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)





def restore_net():
    net1 = torch.load('net.pkl')
    out = net1(x)


    plt.subplot(132)
    plt.title('Net1')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)



def restore_netparams():
    net2 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1),

    )
    net2.load_state_dict(torch.load('net_params.pkl'))
    out = net2(x)

    plt.subplot(133)
    plt.title('Net2')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)

    plt.show()

save()

restore_net()

restore_netparams()

这里详细解释下:

这个代码是我观看莫烦视频教程时候手码的,很久之前了,还是没有理解太透彻....

整个网络首先定义了三个函数

save()      ###保存模型

restore_net()   ###载入整个模型

restore_netparams()  ##载入模型参数

在save()函数中首先通过nn.Sequential()快速创建基础网络,定义好optimizer和LOSS function后开始进行梯度清零,反向传播和梯度更新,100步之后保存模型

torch.save(net,'net.pkl')   ##保存整个模型
torch.save(net.state_dict(),'net_params.pkl')  ##保存整个模型参数

对于两种不同的方法定义不同函数进行模型的载入

对于整个模型的保存载入非常简单,直接使用torch.load(PATH),可以直接载入,但是不推荐这种方法,一是对于大量数据集训练迭代次数一般很多(50000...),所以模型一般会非常大,更重要的是这种保存的模型泛化性极其之差.

    net1 = torch.load('net.pkl')
    out = net1(x)

所以更为推荐下面这种方法

   net2 = torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1),

    )
    net2.load_state_dict(torch.load('net_params.pkl'))
    out = net2(x)

这种方法稍微复杂那么一点点,因为你之前只是保存的整个模型的参数,但是并没有给模型所以你需要调用的话,需要把模型框架copy过来,可以对比一下net和net2 没有区别,另外载入的时候指令也稍微长那么一点,一定要记住

net2.load_state_dict(torch.load('net_params.pkl'))

现在对这三行指令应该有点理解了吧 

torch.save(the_model.state_dict(), PATH)  ##保存模型

the_model = TheModelClass(*args, **kwargs)  ##给定网络
the_model.load_state_dict(torch.load(PATH))  ##载入模型

另外呢,既然提到了网络,之前在使用caffe的时候也利用可视化工具看过Alexnet,vgg16,resnet等网络的模型结构图,具体可以看下之前的博客CNN经典分类模型--AlexNet、VGG16、ResNet网络结构图

如果想用代码实现具体网络的搭建,这里推荐pytorch给定的models文件https://github.com/pytorch/vision/tree/master/torchvision/models

 

每一个py文件都是具体网络的编写,如果想要直接调用的话也非常简单,

import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

当然,如果你想使用pre-trained model ,只需要更改pretrained的布尔值为True就OK ;

import torchvision.models as models
#pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)

 这里给出官方中文的具体介绍,推荐去看下

最后呢,感谢莫烦大佬的视频教程,非常适用于小编这样的初学者入门,给个传送门,去B站投币吧  

莫烦PyTorch入门教程

猜你喜欢

转载自blog.csdn.net/qq_29893385/article/details/84644478