优化器的迁移⚓︎
我们已经将MMGeneration 1.x合并至MMagic。以下是针对MMGeneration中优化器的迁移事项。
在0.x版中,MMGeneration使用PyTorch自带的优化器,其只提供了通用参数优化,而在1.x版中,我们则使用了MMEngine提供的OptimizerWrapper
。
对比PyTorch自带的Optimizer
,OptimizerWrapper
可以支持如下功能:
OptimizerWrapper.update_params
在一个单一的函数中就实现了zero_grad
,backward
和step
- 支持梯度自动累积
- 提供一个名为
OptimizerWrapper.optim_context
的上下文管理器来封装前向进程,optim_context
会根据当前更新迭代数目来自动调用torch.no_sync
,在AMP(Auto Mixed Precision)训练中,autocast
也会在optim_context
中被调用。
对GAN模型,生成器和鉴别器采用不同的优化器和训练策略。要使GAN模型的train_step
函数签名和其它模型的保持一致,我们使用从OptimizerWrapper
继承下来的OptimWrapperDict
来封装生成器和鉴别器的优化器,为了便于该流程的自动化MMagic实现了MultiOptimWrapperContructor
构造器。如你想训练GAN模型,那么应该在你的配置中指定该构造器。
如下是0.x版和1.x版的配置对比
0.x版 | 1.x版 |
---|---|
|
|
注意,在1.x版中,MMGeneration使用
OptimWrapper
来实现梯度累加,这就会导致在0.x版和1.x版之间,discriminator_steps
配置(用于在多次更新鉴别器之后更新一次生成器的训练技巧)与梯度累加均出现不一致问题。
- 在0.x版中,我们在配置里使用
disc_steps
,gen_steps
和batch_accumulation_steps
。disc_steps
和batch_accumulation_steps
会根据train_step
的调用次数来进行统计(亦即dataloader中数据的读取次数)。因此鉴别器的一段连续性更新次数为disc_steps // batch_accumulation_steps
。且对于生成器,gen_steps
是生成器实际的一段连续性更新次数 - 但在1.x版中,我们在配置里则使用了
discriminator_steps
,generator_steps
和accumulative_counts
。discriminator_steps
和generator_steps
指的是自身在更新其它模型之前的一段连续性的更新次数 以BigGAN-128配置为例。
0.x版 | 1.x版 |
---|---|
|
|
最后更新:
November 27, 2023
创建日期: November 27, 2023
创建日期: November 27, 2023