1.保存网络
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)# x data (tensor), shape=(100, 1)
y=x.pow(2)+0.2*torch.rand(x.size())# noisy y data (tensor), shape=(100, 1)
def save():
#建网络
net1=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer=torch.optim.SGD(net1.parameters(),lr=0.5)
loss_func=torch.nn.MSELoss()
#训练
for t in range(100):
prediction=net1(x)
loss=loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net1.state_dict,"net_params.pkl")# 只保存网络中的参数 (速度快, 占内存少)
torch.save(net1,"net.pkl") #保存整个网络,占用的内存大
所以相应的恢复也有两种
1.恢复全网络
这种方式很简单,不需要在构造一个网络,直接net2=torch.load("net.pkl")
(名字要和保存的一致)
2.恢复参数
这种方式则需要构建一个层数,神经元数一致的网络
def restore_params():
# 新建 net2
net2=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
之后在net2.load_state_dict(torch.load("net_params.pkl"))
即可。