-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Open
Labels
bugSomething isn't workingSomething isn't working
Description
System Info
Environment matches verl requirements.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the
examplesfolder (such as GLUE/SQuAD, ...) - My own task or dataset (give details below)
Reproduction
1. 我遇到的问题现象 (Phenomenon)
我在跑训练时遇到了一个 “训练未完成就提前自动断开” 的问题。具体场景是:我配置的是 Step 模式(total_training_steps = 100),因为数据集不大,DataLoader 长度只有 25,配置文件里 total_epochs 用的是默认值 1。
故障复现步骤:
- 配置:设置
total_training_steps = 100,total_epochs = 1(保持默认)。 - 数据:使用较小的数据集,
len(train_dataloader) = 25(小于总步数)。 - 操作:从
global_step = 20的 Checkpoint 处恢复训练 (Resume)。 - 结果:训练器启动后,仅运行了 5 个 step(从 step 20 跑到 25,即当前 Epoch 结束),然后直接停止并退出。
- 关键痛点:训练停止时没有任何日志输出或报错提示(Silent Shutdown),难以排查是因为步数不够还是其他原因导致的退出。
2. 原因排查 (Root Cause Analysis)
我看了一下 dapo_ray_trainer.py 的 fit 函数源码,发现逻辑存在冲突:
-
主要原因:外层循环被 Epoch 锁死
代码里的主循环写的是for epoch in range(config.total_epochs)。在我这个场景下,epoch 上限是 1。
因为我是 Resume 接着跑,Epoch 0 剩下的数据只有 5 个 batch。程序跑完这 5 个 batch 后,Dataloader 耗尽,外层range(1)也耗尽,导致fit函数直接 return,触发了 Ray 的退出逻辑。
本质问题是: 代码没有根据我设定的total_training_steps自动计算需要循环多少个 Epoch 才能喂饱这些步数。 -
次要发现:步数判定不一致
136 行附近的is_last_step判断用的是gen_steps。在 PPO 这类算法里,如果采样步数和更新步数不一致,这个判断逻辑可能会导致 Checkpoint 保存时机不对。
3. 修改建议 (Proposed Solution)
建议在 fit 函数开头加一个动态计算逻辑,不要强依赖配置里的 total_epochs。
修改位置: dapo_ray_trainer.py
建议改法 1:解耦 Loop 限制 (Line ~110)
import math
# [原逻辑]:只读配置,容易不够跑
# for epoch in range(self.config.trainer.total_epochs):
# [建议逻辑]:根据目标 Step 自动算需要跑几轮数据
needed_epochs = math.ceil(self.total_training_steps / len(self.train_dataloader))
actual_epochs = max(self.config.trainer.total_epochs, needed_epochs)
for epoch in range(actual_epochs):
# ...建议改法 2:统一时钟 (Line ~136)
Python
# [原逻辑]:
# is_last_step = self.gen_steps >= self.total_training_steps
# [建议逻辑]:统一用 global_steps (更新步数) 判定结束
is_last_step = self.global_steps >= self.total_training_steps```
### Expected behavior
在配置为 Step-based 模式时,训练器应该自动循环数据(开启新的 Epoch),直到 global_steps 达到设定的 100 步为止。即使从 step 20 恢复且当前 Epoch 仅剩 5 个 batch,程序也应自动衔接到下一个 Epoch 继续训练,而不是直接退出。
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working