多步长MultiStepLR动态调整学习率断点的保存与恢复

1、动态调整学习率以及保存学习率

最近在模型训练时,发现动态调整学习率时,如果训练中断,没有将学习率保存起来,下一次断点训练使用的还是初始学习率。

多步长SGD继续训练:在简单的任务中,我们使用固定步长(也就是学习率LR)进行训练,但是如果学习率lr设置的过小的话,则会导致很难收敛,如果学习率很大的时候,就会导致在最小值附近,总会错过最小值,loss产生震荡,无法收敛。所以这要求我们要对于不同的训练阶段使用不同的学习率,一方面可以加快训练的过程,另一方面可以加快网络收敛。

所以我们在保存网络中的训练的参数的过程中,还需要保存scheduler的state_dict,然后断点继续训练的时候恢复。

#恢复断点
    RESUME = False
    if RESUME:
        path_checkpoint = "/home/sgyj/code/FrequecyTransformer/checkpoint/ckpt_best_55.pth"  # 断点路径
        checkpoint = torch.load(path_checkpoint)  # 加载断点

        net.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

        optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
        start_epoch = checkpoint['epoch']  # 设置开始的epoch
        scheduler.load_state_dict(checkpoint['scheduler'])#恢复scheduler的state_dict

#保存断点
  if (epoch != 0 and epoch % 5 == 0):
            checkpoint = {
    
    
                "net": net.state_dict(),
                'optimizer': optimizer.state_dict(),
                "epoch": epoch,
                'scheduler':scheduler.state_dict()
            }
            if not os.path.isdir("/home/sgyj/code/FrequecyTransformer/checkpoint"):
                os.mkdir("/home/sgyj/code/FrequecyTransformer/checkpoint")
            torch.save(checkpoint, '/home/sgyj/code/FrequecyTransformer/checkpoint/ckpt_best_%s.pth' % (str(epoch)))
        # 每20个epoch保存一次模型

可以参考如下文章
https://blog.csdn.net/weixin_35698091/article/details/112429883

2、打印学习率

动态的打印学习率

print(optimizer.state_dict()['param_groups'][0]['lr'])

猜你喜欢

转载自blog.csdn.net/weixin_44020747/article/details/119537807