Pytorch学习笔记:LambdaLR——自定义学习率变化器

Pytorch学习笔记:LambdaLR——自定义学习率变化器

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose='deprecated')

功能:

  将每个参数的学习率设置为初始的lr乘以一个权重系数factor,用于调整学习率大小,其中权重系数factor由函数lr_lambda得到,这里可以为每个层设置不同的学习率调整策略。

输入:

  • optimizer:优化器;
  • lr_lambda:给定epoch或者,传入函数或list列表;
  • last_epoch:当前的epoch,默认-1;
  • verbose:如果设为True,则每次学习率更新都会输出一条消息(即将弃用,查看学习率可通过调用get_last_lr()实现);

常用方法:

  • get_last_lr():返回当前的学习率

  • state_dict():提取__dict__中的数据(不包括optimizer),如果lr_lambda是一个可调用的对象时,可以被提取,如果是函数或者lambda时,则不会被提取,会得到None

  • load_state_dict(state_dict):加载参数;

代码案例

  对模型中所有参数都使用相同的学习率调整策略,学习率权重因子计算方法如下:
l r = α e p o c h ∗ b a s e _ l r lr=\alpha^{epoch} * base\_lr lr=αepochbase_lr

from torch.optim.lr_scheduler import LambdaLR
from torch.optim import SGD
from torchvision import models


def lambda_lr(epoch, alpha=0.99):
    return alpha ** epoch


model = models.resnet50()
optimizer = SGD(model.parameters(), lr=1e-3)
our_scheduler = LambdaLR(optimizer, lambda_lr)
last_lr = our_scheduler.get_last_lr()

for i in range(100):
    our_scheduler.step()
    last_lr = our_scheduler.get_last_lr()
    print(last_lr)

  对不同的参数层使用不同的学习率调整策略,这里对resnet50的特征提取层和全连接层使用不同的学习率下降策略,其中特征提取层下降速度要快于全连接层。

  首先在定义优化器时,需要将两组参数以不同的键值对传入优化器中,在定义lr_lambda时需要传入两种变化策略(以列表格式传入),注意顺序是一一对应的

from torch.optim.lr_scheduler import LambdaLR
from torch.optim import SGD
from torchvision import models


def fc_lambda_lr(epoch, alpha=0.99):
    return alpha ** epoch


def feature_lambda_lr(epoch, alpha=0.88):
    return alpha ** epoch


model = models.resnet50()
feature_params = []
fc_params = []
for name, param in model.named_parameters():
    if 'fc' in name:
        fc_params.append(param)
    else:
        feature_params.append(param)

optimizer = SGD([
    {
    
    'params': feature_params},
    {
    
    'params': fc_params}
], lr=1e-3)

our_scheduler = LambdaLR(optimizer, lr_lambda=[feature_lambda_lr, fc_lambda_lr])
last_lr = our_scheduler.get_last_lr()

for i in range(100):
    our_scheduler.step()
    last_lr = our_scheduler.get_last_lr()
    print(last_lr)

官方文档

LambdaLR:https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LambdaLR.html#lambdalr

猜你喜欢

转载自blog.csdn.net/qq_50001789/article/details/136033639