diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 46271507934..66966d749ec 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1606,6 +1606,28 @@ def fit(self): with marked_timer("update_actor", timing_raw, color="red"): actor_output = self._update_actor(batch) + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + # update weights from trainer to rollout with marked_timer("update_weights", timing_raw, color="red"): self.checkpoint_manager.update_weights() @@ -1630,30 +1652,6 @@ def fit(self): last_val_metrics = val_metrics metrics.update(val_metrics) - # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. - esi_close_to_expiration = should_save_ckpt_esi( - max_steps_duration=self.max_steps_duration, - redundant_time=self.config.trainer.esi_redundant_time, - ) - # Check if the conditions for saving a checkpoint are met. - # The conditions include a mandatory condition (1) and - # one of the following optional conditions (2/3/4): - # 1. The save frequency is set to a positive value. - # 2. It's the last training step. - # 3. The current step number is a multiple of the save frequency. - # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration - ): - if esi_close_to_expiration: - print("Force saving checkpoint: ESI instance expiration approaching.") - with marked_timer("save_checkpoint", timing_raw, color="green"): - # sleep replicas to avoid OOM during checkpoint saving - self.checkpoint_manager.sleep_replicas() - self._save_checkpoint() - # wake replicas to avoid OOM during checkpoint saving - self.checkpoint_manager.update_weights() - with marked_timer("stop_profile", timing_raw): next_step_profile = ( self.global_steps + 1 in self.config.global_profiler.steps