莫烦pytorch(7)——保存提取

1.保存网络

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)# x data (tensor), shape=(100, 1)
y=x.pow(2)+0.2*torch.rand(x.size())# noisy y data (tensor), shape=(100, 1)

def save():
    #建网络
    net1=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)

    )
    optimizer=torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func=torch.nn.MSELoss()

    #训练
    for t in  range(100):
        prediction=net1(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    torch.save(net1.state_dict,"net_params.pkl")# 只保存网络中的参数 (速度快, 占内存少)
    torch.save(net1,"net.pkl")							#保存整个网络,占用的内存大

所以相应的恢复也有两种

1.恢复全网络

这种方式很简单,不需要在构造一个网络,直接net2=torch.load("net.pkl")(名字要和保存的一致)

2.恢复参数

这种方式则需要构建一个层数,神经元数一致的网络

def restore_params():
    # 新建 net2
    net2=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

之后在net2.load_state_dict(torch.load("net_params.pkl"))即可。

猜你喜欢

转载自blog.csdn.net/qq_42738654/article/details/87968672
今日推荐