The lr_scheduler.StepLR learning-rate adjustment mechanism
Learning-rate adjustments can be made with the help of the classes in torch.optim.lr_scheduler; this module provides several methods for adjusting the learning rate based on the number of training epochs. Under normal circumstances, we gradually reduce the learning rate as the epoch count increases in order to achieve better training results.
One adjustment strategy is introduced below: the StepLR mechanism.
1. torch.optim.lr_scheduler.StepLR
Function prototype:
class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
Update process:
The learning rate is adjusted at equal intervals: every step_size steps, the current learning rate is multiplied by gamma. Note that the interval unit "step" here usually refers to an epoch, not an iteration.
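In other words, as long as scheduler.step() is called once per completed epoch, the effective learning rate follows a simple closed form. A minimal sketch of that relationship (expected_lr is a hypothetical helper written for illustration, not part of the PyTorch API):

def expected_lr(initial_lr, gamma, step_size, epoch):
    # Learning rate after `epoch` completed epochs, assuming
    # scheduler.step() is called exactly once per epoch
    return initial_lr * gamma ** (epoch // step_size)

# e.g. expected_lr(0.1, 0.1, 3, 0) == 0.1
#      expected_lr(0.1, 0.1, 3, 3) ~= 0.01  (first decay after 3 epochs)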
Parameters:
- optimizer (Optimizer): the optimizer whose learning rate will be adjusted
- step_size (int): the learning rate is decayed once every step_size training epochs
- gamma (float): multiplicative factor of learning-rate decay (default 0.1)
- last_epoch (int): the index of the last epoch. If training is interrupted after some number of epochs and then resumed, this value should equal the epoch of the loaded model. The default is -1, meaning training starts from scratch.
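For the interrupted-training case described under last_epoch, it is usually simpler to save and restore the scheduler state than to pass the index by hand. A minimal sketch, assuming an arbitrary checkpoint file name ("checkpoint.pth") and a throwaway model; state_dict and load_state_dict are the actual PyTorch methods:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

# A tiny model just to have parameters to optimize (illustrative only)
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# ... train for a few epochs, calling scheduler.step() each epoch ...

# Save everything needed to resume
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}, "checkpoint.pth")

# Later: rebuild the objects and restore their state; the schedule
# then continues from the saved epoch instead of restarting at epoch 0
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])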
Example program:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

initial_lr = 0.1

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x):
        pass

net_1 = model()
optimizer_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr)
scheduler_1 = StepLR(optimizer_1, step_size=3, gamma=0.1)

print("Initial learning rate:", optimizer_1.defaults['lr'])

for epoch in range(1, 11):
    # train ...
    optimizer_1.zero_grad()
    optimizer_1.step()  # optimizer.step() must come before scheduler.step()
    print("Learning rate at epoch %d: %f" % (epoch, optimizer_1.param_groups[0]['lr']))
    scheduler_1.step()  # decay the learning rate once per epoch
Output:
Initial learning rate: 0.1
Learning rate at epoch 1: 0.100000
Learning rate at epoch 2: 0.100000
Learning rate at epoch 3: 0.100000
Learning rate at epoch 4: 0.010000
Learning rate at epoch 5: 0.010000
Learning rate at epoch 6: 0.010000
Learning rate at epoch 7: 0.001000
Learning rate at epoch 8: 0.001000
Learning rate at epoch 9: 0.001000
Learning rate at epoch 10: 0.000100
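These values match the closed form given earlier: since the print happens before scheduler_1.step(), epoch e has seen e - 1 scheduler steps. A quick check using the hypothetical expected_lr helper defined above:

# Reproduce the table above from the closed form
for epoch in range(1, 11):
    # (epoch - 1) scheduler steps have run before the print at this epoch
    print("Learning rate at epoch %d: %f"
          % (epoch, expected_lr(0.1, 0.1, 3, epoch - 1)))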