@@ -66,10 +66,10 @@ class Optim:
     total_training_steps: int = -1  # ! DO NOT SET, use trainer.total_steps
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     clip_grad: float = 1.0
-    lr_warmup_init: float = 0.0
+    lr_warmup_init: Optional[float] = None  # 0.0
     lr_decay_steps: Optional[int] = None
-    lr_decay_style: str = "constant"
-    min_lr: float = 0.0
+    lr_decay_style: Optional[str] = None  # "constant"
+    min_lr: Optional[float] = None  # 0.0
     weight_decay: float = 0.01
     weight_decay_incr_style: str = "constant"
     lr_wsd_decay_style: str = "exponential"
@@ -607,22 +607,32 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.critic.strategy = "fsdp"

         # Algorithm related config
-        for field_name in config.algorithm.optimizer.__dataclass_fields__:
-            field_value = getattr(config.algorithm.optimizer, field_name)
+        actor_optim = self.actor_rollout_ref.actor.optim
+        critic_optim = self.critic.optim
+        optim_config = config.algorithm.optimizer
+        for field_name in optim_config.__dataclass_fields__:
+            field_value = getattr(optim_config, field_name)
             if field_name == "optimizer_type":
-                setattr(self.actor_rollout_ref.actor.optim, "optimizer", field_value)
-            elif hasattr(self.actor_rollout_ref.actor.optim, field_name):
-                setattr(self.actor_rollout_ref.actor.optim, field_name, field_value)
+                setattr(actor_optim, "optimizer", field_value)
+            elif hasattr(actor_optim, field_name):
+                setattr(actor_optim, field_name, field_value)
+        # ensure megatron optimizer config compatibility
+        set_if_none(actor_optim, "lr_warmup_init", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(actor_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(actor_optim, "lr_decay_style", optim_config.lr_scheduler_type)
+        set_if_none(actor_optim, "min_lr", optim_config.min_lr_ratio * optim_config.lr)
+        set_if_none(critic_optim, "lr_warmup_init", 0.0)
+        set_if_none(critic_optim, "lr_decay_steps", self.trainer.total_training_steps)
+        set_if_none(critic_optim, "lr_decay_style", "constant")
+        set_if_none(critic_optim, "min_lr", 0.0)
         # fix optimizer type for fsdp
         if config.trainer.trainer_strategy.startswith("fsdp"):
             optim_map = {
                 "adam": "AdamW",
                 "adamw": "AdamW",
                 "sgd": "SGD",
             }
-            actor_optim = self.actor_rollout_ref.actor.optim
             actor_optim.optimizer = optim_map.get(actor_optim.optimizer, actor_optim.optimizer)
-            critic_optim = self.critic.optim
             critic_optim.optimizer = optim_map.get(critic_optim.optimizer, critic_optim.optimizer)
         self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none"
         self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none"
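The hunk above relies on a `set_if_none` helper whose definition is not part of this diff. Based on its name and on the fields that were switched to `Optional[...] = None` in the first hunk, it presumably writes a fallback value only when the attribute is still unset, so explicit user overrides are preserved. Below is a minimal sketch under that assumption, with a trimmed `Optim` dataclass for illustration; the exact helper in the repository may differ.

```python
from dataclasses import dataclass
from typing import Any, Optional


def set_if_none(obj: Any, name: str, value: Any) -> None:
    """Assumed behavior: fill `name` on `obj` only if it is currently None."""
    if getattr(obj, name, None) is None:
        setattr(obj, name, value)


@dataclass
class Optim:  # trimmed to the fields touched in the first hunk
    lr: float = 1e-6
    lr_warmup_init: Optional[float] = None   # previously 0.0
    lr_decay_steps: Optional[int] = None
    lr_decay_style: Optional[str] = None     # previously "constant"
    min_lr: Optional[float] = None           # previously 0.0


optim = Optim(lr_decay_style="cosine")            # user-provided value
set_if_none(optim, "lr_decay_style", "constant")  # no-op: already set
set_if_none(optim, "min_lr", 0.0)                 # applied: still None
assert optim.lr_decay_style == "cosine" and optim.min_lr == 0.0
```

With `None` as the sentinel for "unset", `synchronize_config` can apply the Megatron-compatible defaults (warmup init, decay steps, decay style, min lr) without clobbering values the user configured explicitly.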