Pytorch——保存训练好的模型参数

1.前言

训练好了一个模型, 我们当然想要保存它, 留到下次要用的时候直接提取直接用,下面我将来讲如何存储训练好的模型参数

2.torch.save(保存模型)

首先,先搭建一个神经网络

import torch
from torch import nn
import matplotlib.pyplot as plt
torch.manual_seed(11)    # 使每次得到的随机数是固定的。但是如果不加上torch.manual_seed这个函数调用的话,打印出来的随机数每次都不一样


x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # [100] -> [100,1]
y = x.pow(2) + 0.5*torch.rand(x.size())  # y的形状与x一样


def make_and_save_model():
    
    network = torch.nn.Sequential(
        torch.nn.Linear(1, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 1)
    )
    optimizer = torch.optim.SGD(network.parameters(), lr=0.3)   #优化器
    criterion = torch.nn.MSELoss()     #损失函数

    # 训练
    for i in range(200):
        prediction = network(x)      #数据放入模型后得到预测值
        loss = criterion(prediction, y)    #计算预测值与真实值之间的误差
        optimizer.zero_grad()       #清空梯度
        loss.backward()           #误差反向传播
        optimizer.step()          #更新参数
    torch.save(network, 'network.pth')  # 保存整个网络
    torch.save(network.state_dict(), 'network_params.pth')   # 只保存网络中的参数
    
    plt.figure(1, figsize = (10,3))
    plt.subplot(131)
    plt.title('network')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
    plt.pause(1)

3.torch.load整个网络

这种方式将会提取整个神经网络, 网络大的时候可能会比较慢.

def load_whole_model():

    network_whole = torch.load('network.pth')
    prediction = network_whole(x)
    
    plt.figure(1, figsize = (10,3))
    plt.subplot(132)
    plt.title('network_whole')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)
    plt.pause(1)

4.torch.load网络参数(只提取参数)

这种方式将会提取所有的参数, 然后再放到你的新建网络中

def load_only_params():
    
    network_params = torch.nn.Sequential(
        torch.nn.Linear(1, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 1)
    )

    network_params.load_state_dict(torch.load('network_params.pth'))
    prediction = network_params(x)
    
    plt.figure(1, figsize = (10,3))
    plt.subplot(133)
    plt.title('network_params')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'yo' , lw = 5)

5.调用三个函数

会看到加载后的模型画出的图是一样的,说明模型的参数正确加载了。

make_and_save_model()
load_whole_model()
load_only_params()

在这里插入图片描述

发布了122 篇原创文章 · 获赞 341 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weixin_37763870/article/details/104815034