6.2 动态调整学习率⚓︎
学习率的选择是深度学习中一个困扰人们许久的问题,学习速率设置过小,会极大降低收敛速度,增加训练时间;学习率太大,可能导致参数在最优解两侧来回振荡。但是当我们选定了一个合适的学习率后,经过许多轮的训练后,可能会出现准确率震荡或loss不再下降等情况,说明当前学习率已不能满足模型调优的需求。此时我们就可以通过一个适当的学习率衰减策略来改善这种现象,提高我们的精度。这种设置方式在PyTorch中被称为scheduler,也是我们本节所研究的对象。
经过本节的学习,你将收获:
- 如何根据需要选取已有的学习率调整策略
- 如何自定义设置学习调整策略并实现
6.2.1 使用官方scheduler⚓︎
- 了解官方提供的API
在训练神经网络的过程中,学习率是最重要的超参数之一,作为当前较为流行的深度学习框架,PyTorch已经在torch.optim.lr_scheduler
为我们封装好了一些动态调整学习率的方法供我们使用,如下面列出的这些scheduler。
lr_scheduler.LambdaLR
lr_scheduler.MultiplicativeLR
lr_scheduler.StepLR
lr_scheduler.MultiStepLR
lr_scheduler.ExponentialLR
lr_scheduler.CosineAnnealingLR
lr_scheduler.ReduceLROnPlateau
lr_scheduler.CyclicLR
lr_scheduler.OneCycleLR
lr_scheduler.CosineAnnealingWarmRestarts
lr_scheduler.ConstantLR
lr_scheduler.LinearLR
lr_scheduler.PolynomialLR
lr_scheduler.ChainedScheduler
lr_scheduler.SequentialLR
这些scheduler都是继承自_LRScheduler
类,我们可以通过help(torch.optim.lr_scheduler)
来查看这些类的具体使用方法,也可以通过help(torch.optim.lr_scheduler._LRScheduler)
来查看_LRScheduler
类的具体使用方法。
- 使用官方API
关于如何使用这些动态调整学习率的策略,PyTorch
官方也很人性化的给出了使用实例代码帮助大家理解,我们也将结合官方给出的代码来进行解释。
# 选择一种优化器
optimizer = torch.optim.Adam(...)
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 进行训练
for epoch in range(100):
train(...)
validate(...)
optimizer.step()
# 需要在优化器参数更新之后再动态调整学习率
# scheduler的优化是在每一轮后面进行的
scheduler1.step()
...
schedulern.step()
注:
我们在使用官方给出的torch.optim.lr_scheduler
时,需要将scheduler.step()
放在optimizer.step()
后面进行使用。
6.2.2 自定义scheduler⚓︎
虽然PyTorch官方给我们提供了许多的API,但是在实验中也有可能碰到需要我们自己定义学习率调整策略的情况,而我们的方法是自定义函数adjust_learning_rate
来改变param_group
中lr
的值,在下面的叙述中会给出一个简单的实现。
假设我们现在正在做实验,需要学习率每30轮下降为原来的1/10,假设已有的官方API中没有符合我们需求的,那就需要自定义函数来实现学习率的改变。
def adjust_learning_rate(optimizer, epoch):
lr = args.lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
有了adjust_learning_rate
函数的定义,在训练的过程就可以调用我们的函数来实现学习率的动态变化
def adjust_learning_rate(optimizer,...):
...
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
train(...)
validate(...)
adjust_learning_rate(optimizer,epoch)
本节参考⚓︎
创建日期: November 30, 2023