1. Quick build method
In the previous two articles we built the neural network like this:
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)    # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)    # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))    # the activation is applied here, not stored as a layer
        x = self.predict(x)
        return x


net1 = Net(1, 10, 1)    # net1, built the way we did before
There is also another way:
net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)
Comparison:
print(net1)
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
print(net2)
"""
Sequential (
  (0): Linear (1 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 1)
)
"""
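(1) net2 does not name its layers; it numbers them. net1 takes its layer names from the attributes defined in __init__.
(2) net2 also includes the activation function as a layer. In net1, the activation function is only called inside forward(). The net1 style lets you customize the forward pass as needed (for an RNN, for example). If you don't need anything beyond a straight chain of layers, the net2 style is probably the better fit.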
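If numbered layers are the only thing you dislike about the Sequential style, the layers can also be given names by passing an OrderedDict. This is a minimal sketch and not part of the original tutorial; the layer names `hidden`, `act`, `predict` and the variable `net_named` are chosen here for illustration.

```python
from collections import OrderedDict

import torch

# Sequential accepts an OrderedDict, so each layer gets a name instead of an index.
# The layer names below are illustrative, not from the original tutorial.
net_named = torch.nn.Sequential(OrderedDict([
    ("hidden", torch.nn.Linear(1, 10)),
    ("act", torch.nn.ReLU()),
    ("predict", torch.nn.Linear(10, 1)),
]))

layer_names = [name for name, _ in net_named.named_children()]
print(layer_names)

# Named layers are also reachable as attributes, like in the class-based net1.
out = net_named(torch.zeros(4, 1))
print(net_named.hidden, tuple(out.shape))
```

The forward pass is still a fixed chain of layers, so the class-based style remains the one to use when forward() needs custom logic.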
2. Saving and restoring:
2.1 Build and save:
import torch
from torch.autograd import Variable

torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)    # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())    # noisy y data (tensor), shape=(100, 1)
# Since PyTorch 0.4, Variable has been merged into Tensor; this line is kept from the original and can be dropped
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)


def save():
    # build the network
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    # train
    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    torch.save(net1, 'net.pkl')    # save the entire network
    torch.save(net1.state_dict(), 'net_params.pkl')    # save only the parameters (faster, smaller)
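To see what "only the parameters" means, it helps to look at what state_dict() returns for this architecture. The sketch below is an illustration added here, not part of the original tutorial; the variable names `net` and `param_shapes` are assumptions.

```python
import torch

net = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)

# state_dict() maps parameter names to tensors. ReLU has no parameters,
# so index 1 does not appear among the keys.
param_shapes = {name: tuple(t.shape) for name, t in net.state_dict().items()}
for name, shape in param_shapes.items():
    print(name, shape)
```

The saved file holds just these named tensors, not the network object itself, which is why a matching network has to be rebuilt before the parameters can be loaded.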
2.2 Restoring the whole network
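def restore_net():
    # restore the entire net1 into net2
    # Note: recent PyTorch releases (2.6 and later) default torch.load to weights_only=True,
    # so loading a whole pickled module there needs weights_only=False (trusted files only)
    net2 = torch.load('net.pkl')
    prediction = net2(x)    # run x through the restored network
    print(net2)
    print(prediction)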
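The whole-network round trip can be checked on its own with a small sketch like the one below. It is not from the original tutorial and assumes a PyTorch version whose torch.load accepts the `weights_only` keyword; the temporary path and the names `net`, `restored`, and `same` are illustrative.

```python
import os
import tempfile

import torch

net = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
)
x = torch.linspace(-1, 1, 5).unsqueeze(1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net.pkl")    # temporary, illustrative path
    torch.save(net, path)                  # pickles the whole module object
    # weights_only=False permits unpickling arbitrary objects: trusted files only.
    restored = torch.load(path, weights_only=False)

with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print(type(restored).__name__, same)
```

Because the whole object is pickled, loading depends on the class definitions being importable, which is one reason the parameters-only approach is often preferred.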
...
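2.3 Restoring only the parameters
You first need to build a new network with the same architecture, then copy the saved parameters into it. This is somewhat faster than restoring the whole network.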
def restore_params():
    # build a new network, net3, with the same architecture
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # copy the saved parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction1 = net3(x)
    print(net3)
    print(prediction1)
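A quick way to confirm that the copied parameters really reproduce the saved network is to compare the outputs of both networks on the same input. The following is a minimal sketch under a few assumptions: the helper `make_net`, the untrained stand-in `trained`, and the temporary file path are illustrative and not part of the original tutorial.

```python
import os
import tempfile

import torch


def make_net():
    # Same architecture as the tutorial's network; the helper name is illustrative.
    return torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )


trained = make_net()    # stands in for a trained network
fresh = make_net()      # starts with different random weights
x = torch.linspace(-1, 1, 5).unsqueeze(1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net_params.pkl")
    torch.save(trained.state_dict(), path)
    # load_state_dict copies the saved tensors into the new network in place.
    fresh.load_state_dict(torch.load(path))

with torch.no_grad():
    match = torch.equal(trained(x), fresh(x))
print(match)
```

If the new network's architecture did not match the saved keys and shapes, load_state_dict would raise an error instead of silently loading.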
2.4 Displaying the results
# save net1 (1. the whole network, 2. only the parameters)
save()

# restore the whole network
restore_net()

# restore the parameters and copy them into a new network
restore_params()
3. Results:
Note that Python is whitespace-sensitive, so the indentation of the code above must be preserved.
References:
https://morvanzhou.github.io/tutorials/machine-learning/torch/3-04-save-reload/
https://morvanzhou.github.io/tutorials/machine-learning/torch/3-03-fast-nn/