Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,28 @@ def fit(self):
with marked_timer("update_actor", timing_raw, color="red"):
actor_output = self._update_actor(batch)

# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# Check if the conditions for saving a checkpoint are met.
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
# 2. It's the last training step.
# 3. The current step number is a multiple of the save frequency.
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
is_last_step
or self.global_steps % self.config.trainer.save_freq == 0
or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()

# update weights from trainer to rollout
with marked_timer("update_weights", timing_raw, color="red"):
self.checkpoint_manager.update_weights()
Expand All @@ -1630,30 +1652,6 @@ def fit(self):
last_val_metrics = val_metrics
metrics.update(val_metrics)

# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# Check if the conditions for saving a checkpoint are met.
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
# 2. It's the last training step.
# 3. The current step number is a multiple of the save frequency.
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
# sleep replicas to avoid OOM during checkpoint saving
self.checkpoint_manager.sleep_replicas()
self._save_checkpoint()
# wake replicas to avoid OOM during checkpoint saving
self.checkpoint_manager.update_weights()

with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
Expand Down
Loading