优化器的迁移⚓︎
我们已经将MMGeneration 1.x合并至MMagic。以下是针对MMGeneration中优化器的迁移事项。
在0.x版中,MMGeneration使用PyTorch自带的优化器,其只提供了通用参数优化,而在1.x版中,我们则使用了MMEngine提供的OptimizerWrapper。
对比PyTorch自带的Optimizer,OptimizerWrapper可以支持如下功能:
OptimizerWrapper.update_params在一个单一的函数中就实现了zero_grad,backward和step- 支持梯度自动累积
- 提供一个名为
OptimizerWrapper.optim_context的上下文管理器来封装前向进程,optim_context会根据当前更新迭代数目来自动调用torch.no_sync,在AMP(Auto Mixed Precision)训练中,autocast也会在optim_context中被调用。
对GAN模型,生成器和鉴别器采用不同的优化器和训练策略。要使GAN模型的train_step函数签名和其它模型的保持一致,我们使用从OptimizerWrapper继承下来的OptimWrapperDict来封装生成器和鉴别器的优化器,为了便于该流程的自动化MMagic实现了MultiOptimWrapperContructor构造器。如你想训练GAN模型,那么应该在你的配置中指定该构造器。
如下是0.x版和1.x版的配置对比
| 0.x版 | 1.x版 |
|---|---|
|
|
注意,在1.x版中,MMGeneration使用
OptimWrapper来实现梯度累加,这就会导致在0.x版和1.x版之间,discriminator_steps配置(用于在多次更新鉴别器之后更新一次生成器的训练技巧)与梯度累加均出现不一致问题。
- 在0.x版中,我们在配置里使用
disc_steps,gen_steps和batch_accumulation_steps。disc_steps和batch_accumulation_steps会根据train_step的调用次数来进行统计(亦即dataloader中数据的读取次数)。因此鉴别器的一段连续性更新次数为disc_steps // batch_accumulation_steps。且对于生成器,gen_steps是生成器实际的一段连续性更新次数 - 但在1.x版中,我们在配置里则使用了
discriminator_steps,generator_steps和accumulative_counts。discriminator_steps和generator_steps指的是自身在更新其它模型之前的一段连续性的更新次数 以BigGAN-128配置为例。
| 0.x版 | 1.x版 |
|---|---|
|
|
最后更新:
November 27, 2023
创建日期: November 27, 2023
创建日期: November 27, 2023