【深度学习】(7)--神经网络之保存最优模型

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset) # 总数据大小
    num_batches = len(dataloader) # 划分的小批次数量
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{
      
      (100*correct)}%,Avg loss:{
      
      test_loss}")

    # 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
        print(model.state_dict().keys()) # 输出模型参数名称cnn
        torch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{
      
      (100*correct)}%,Avg loss:{
      
      test_loss}")

# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 2. 保存完整模型(w,b,模型cnn)
        torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)

epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
    print(f"Epoch {
      
      t+1} \n-------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

猜你喜欢

转载自blog.csdn.net/m0_74896766/article/details/142514868