1、动态调整学习率以及保存学习率
最近在模型训练时,发现动态调整学习率时,如果训练中断,没有将学习率保存起来,下一次断点训练使用的还是初始学习率。
多步长SGD继续训练:在简单的任务中,我们使用固定步长(也就是学习率LR)进行训练,但是如果学习率lr设置的过小的话,则会导致很难收敛,如果学习率很大的时候,就会导致在最小值附近,总会错过最小值,loss产生震荡,无法收敛。所以这要求我们要对于不同的训练阶段使用不同的学习率,一方面可以加快训练的过程,另一方面可以加快网络收敛。
所以我们在保存网络中的训练的参数的过程中,还需要保存scheduler的state_dict,然后断点继续训练的时候恢复。
#恢复断点
RESUME = False
if RESUME:
path_checkpoint = "/home/sgyj/code/FrequecyTransformer/checkpoint/ckpt_best_55.pth" # 断点路径
checkpoint = torch.load(path_checkpoint) # 加载断点
net.load_state_dict(checkpoint['net']) # 加载模型可学习参数
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] # 设置开始的epoch
scheduler.load_state_dict(checkpoint['scheduler'])#恢复scheduler的state_dict
#保存断点
if (epoch != 0 and epoch % 5 == 0):
checkpoint = {
"net": net.state_dict(),
'optimizer': optimizer.state_dict(),
"epoch": epoch,
'scheduler':scheduler.state_dict()
}
if not os.path.isdir("/home/sgyj/code/FrequecyTransformer/checkpoint"):
os.mkdir("/home/sgyj/code/FrequecyTransformer/checkpoint")
torch.save(checkpoint, '/home/sgyj/code/FrequecyTransformer/checkpoint/ckpt_best_%s.pth' % (str(epoch)))
# 每20个epoch保存一次模型
可以参考如下文章
https://blog.csdn.net/weixin_35698091/article/details/112429883
2、打印学习率
动态的打印学习率
print(optimizer.state_dict()['param_groups'][0]['lr'])