feat: add async RL support #1098
Conversation
awesome work @parthchadha. Do you think you could maybe add a sequence diagram like this https://docs.mermaidchart.com/mermaid-oss/syntax/sequenceDiagram.html#actors somewhere in the docs to help understand the flow of data/stalling? Still reviewing, but this is an initial pass.
also, does this close #600? If so, could you add it to the PR description so it'll auto-close when completed?
print(f" - train_global_batch_size: {train_gbs}") | ||
print(f" - min_trajectories_needed: {min_trajectories_needed} (async mode)") | ||
|
||
_replay_py_exec = get_actor_python_env( |
did using RayWorkerGroup not work for this use-case? too much boilerplate?
I think it didn't work, don't recall the error but will try again and update the code if it works.
i wonder if it's due to the same issue @RayenTian faced in the reward model pr:

nemo_rl/algorithms/grpo.py (Outdated)
# Clean up
print("🛑 Stopping trajectory collection...")
try:
    ray.get(trajectory_collector.stop.remote())
if you call stop and it set()s all the threading events, but kill the actor right after, it seems like the amount of time elapsed is negligible, so the ray.kill following this may still be "disruptive". so is this "stop" even needed?
i can see this being useful for unit tests though
Agreed, we can delete stop and just kill; since we aren't storing the trajectories in a serialized form for further use right now, it's fine to just kill the trajectory collector.
nemo_rl/algorithms/async_utils.py (Outdated)
self.max_size = max_size
self.trajectories = []
self.trajectory_versions = []  # weight-version used for generation
self.target_weight_versions = []  # weight-version this trajectory is targeted for
could you add a comment somewhere to help disambiguate these two versions?
took a stab:
sequenceDiagram
autonumber
participant Trainer
participant AsyncCollector as AsyncTrajectoryCollector
participant Dataloader
participant Policy as PolicyGeneration
participant Envs as Environments
participant Buffer as ReplayBuffer
Trainer->>AsyncCollector: set_weight_version(t)
Note right of AsyncCollector: g = current generation_weight_version = t
par Continuous collection
AsyncCollector->>Dataloader: next batch
AsyncCollector->>AsyncCollector: target_weights = _calculate_target_weights(g)
AsyncCollector->>Buffer: get_last_target_weight_already_generated()
Buffer-->>AsyncCollector: last_tgt
AsyncCollector->>AsyncCollector: tgt = _get_next_target_for_generation(g)\n(reserve if tgt > last_tgt and not in-flight)
alt tgt found
loop for each prompt in batch
AsyncCollector->>Policy: run_async_multi_turn_rollout(repeated prompt)
Policy->>Envs: step through rollout
Envs-->>Policy: transitions
Policy-->>AsyncCollector: trajectory_group
AsyncCollector->>Buffer: push_with_wait_signal(group, weight_version=g, target_weight_version=tgt)
Buffer->>Buffer: trajectories += group
Buffer->>Buffer: trajectory_versions += g
Buffer->>Buffer: target_weight_versions += tgt
Buffer->>Buffer: last_target_weight_already_generated = max(last, tgt)
Buffer-->>AsyncCollector: "success" | "full"
end
else no tgt available
AsyncCollector-->>AsyncCollector: pause/wait (all targets covered or reserved)
end
and Training
loop until enough groups for step t
Trainer->>Buffer: sample(num_groups, current_weight_version=t, max_age)
Buffer->>Buffer: min_valid = max(0, t - max_age)
Buffer->>Buffer: valid_indices: min_valid <= generation g <= t
Buffer->>Buffer: intended_indices: target_weight_versions == t
alt enough intended groups
Buffer-->>Trainer: trajectories + avg_trajectory_age
Buffer->>Buffer: remove selected entries
else insufficient
Buffer-->>Trainer: None (stall until more for target t)
end
end
end
Note over AsyncCollector: _calculate_target_weights(g)\n- If g == initial_weight_version:\n [initial, ..., initial+max_age]\n (include current step at cold start)\n- Else: [g+1, ..., g+max_age]\n (only future targets once warm)
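To make the targeting rule in the note above concrete, here is a minimal, standalone sketch of the described behavior (the function name and parameters are paraphrased from this discussion, not the PR's exact signatures):

```python
def calculate_target_weights(
    generation_weight_version: int,
    initial_weight_version: int,
    max_trajectory_age_steps: int,
) -> list[int]:
    """Sketch of the targeting rule described in the note above (not the PR's code)."""
    if generation_weight_version == initial_weight_version:
        # Cold start: the current step itself is still a valid target.
        start = initial_weight_version
    else:
        # Warm: only future steps within the age window are valid targets.
        start = generation_weight_version + 1
    return list(range(start, generation_weight_version + max_trajectory_age_steps + 1))


# Cold start with max_age=1: step 0 itself plus one step of slack.
assert calculate_target_weights(0, 0, 1) == [0, 1]
# Warm, matching the docstring example further down: g=10, max_age=4 -> [11, 12, 13, 14].
assert calculate_target_weights(10, 0, 4) == [11, 12, 13, 14]
```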
Walkthrough

Adds an optional asynchronous GRPO training path. New async utilities (ReplayBuffer, AsyncTrajectoryCollector) with Ray actors manage background trajectory collection. GRPO gains async_grpo_train and memory cleanup. Entry script toggles async vs sync based on config. Configs gain grpo.async_grpo fields. Ray actor environment registry updated for new actors.

Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant Runner as run_grpo_math.py
participant GRPO as nemo_rl.algorithms.grpo
Note over Runner,GRPO: Training entry
User->>Runner: Launch with config
Runner->>Runner: Read grpo.async_grpo.enabled
alt async enabled
Runner->>GRPO: async_grpo_train(..., grpo_save_state, max_trajectory_age_steps)
else sync
Runner->>GRPO: grpo_train(...)
end
sequenceDiagram
autonumber
participant Trainer as async_grpo_train
participant RB as ReplayBuffer (Ray)
participant Collector as AsyncTrajectoryCollector (Ray)
participant Policy as Policy/Generation
participant Env as Envs
participant Val as Validator
Note over Trainer: Initialization
Trainer->>RB: start actor
Trainer->>Collector: start actor (policy, tokenizer, envs)
Trainer->>Collector: set_weight_version(v0)
loop training steps
Note over Collector,Env: Background rollout
Collector->>Env: run_async_multi_turn_rollout
Env-->>Collector: trajectories
Collector->>RB: push_with_wait_signal(groups, gen_w, target_w)
Note over Trainer,RB: Sampling
Trainer->>RB: sample(num_groups, current_w, max_age)
RB-->>Trainer: groups or None
alt got sample
Trainer->>Policy: logprob/inference
Policy-->>Trainer: scores
Trainer->>Trainer: compute baselines/advantages, optimize
opt optional refit
Trainer->>Collector: pause/prepare_for_refit
Trainer->>Policy: update weights (version++)
Trainer->>Collector: set_weight_version/new, resume_after_refit
end
else stall
Trainer->>Trainer: wait/backoff
end
opt periodic validation
Trainer->>Collector: pause
Trainer->>Val: run validation
Val-->>Trainer: metrics
Trainer->>Trainer: gc.collect(), torch.cuda.empty_cache()
Trainer->>Collector: resume
end
opt checkpoint
Trainer->>Trainer: save state
end
end
Note over Trainer,RB: Cleanup on exit
Trainer->>Collector: stop/cleanup
Trainer->>RB: clear/stop
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
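As a rough illustration of the trainer-side stall branch in the second diagram, here is a self-contained toy loop (the real trainer talks to Ray actors and uses the PR's sample() signature; this in-process fake buffer is purely illustrative):

```python
import time


class FakeReplayBuffer:
    """In-process stand-in for the Ray ReplayBuffer actor, for illustration only."""

    def __init__(self, arrivals: list[str]):
        self._arrivals = arrivals  # groups that "arrive" from the collector over time

    def sample(self, num_groups: int):
        # Return None when not enough groups targeted at the current step exist yet.
        if len(self._arrivals) < num_groups:
            return None
        out, self._arrivals = self._arrivals[:num_groups], self._arrivals[num_groups:]
        return out


buffer = FakeReplayBuffer([f"group_{i}" for i in range(4)])
for step in range(2):
    while (groups := buffer.sample(num_groups=2)) is None:
        time.sleep(0.5)  # the trainer stalls/backs off until the collector catches up
    print(f"step {step}: training on {groups}")
```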
Actionable comments posted: 5
♻️ Duplicate comments (1)
nemo_rl/algorithms/grpo.py (1)
1127-1129: Fix: incorrect check for importance sampling; loss_fn is not subscriptable. This asserts into the callable loss object; use the loss config instead.
Apply:
- assert loss_fn["use_importance_sampling_correction"] is True, (
+ assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, (
      "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
  )
🧹 Nitpick comments (22)
examples/configs/async_grpo_math_1B.yaml (4)
7-7: Fix stale comment for trajectory age. The comment says "last 4 training steps" but the value is 1. Align the comment with the value.
- max_trajectory_age_steps: 1 # Allow trajectories from the last 4 training steps
+ max_trajectory_age_steps: 1 # Allow trajectories from the last 1 training step
34-41: Normalize boolean casing for consistency. Mix of True/False and true/false. Prefer one style (repo convention) for readability.
- cpu_offload: False
+ cpu_offload: false
- activation_checkpointing: false
+ activation_checkpointing: false
- enabled: True
+ enabled: true
- enabled: False
+ enabled: false
- enforce_eager: False
+ enforce_eager: false
Also applies to: 43-47, 58-66, 70-70, 88-95, 96-109
58-64: Optional: double-check vLLM lengths. max_new_tokens equals max_total_sequence_length and vLLM max_model_len equals the same. Depending on prompt length, this can cap generations early or waste headroom. Consider setting max_new_tokens = max_total_sequence_length - max_input_seq_length (or a fixed cap) if that matches your runtime assumptions.
Also applies to: 49-56
82-82: Clean up trailing spaces and add a newline at EOF. Yamllint flagged trailing spaces and a missing newline (on the "# Environment configuration" line and after "flush_interval: 10").
Also applies to: 108-109
examples/configs/async_grpo_math_8B.yaml (3)
21-21: Clarify the generation batch-size note. The comment says "Only used when generating using HF backend" while the backend is vLLM. Either drop the note or indicate it's ignored with vLLM to avoid confusion.
- generation_batch_size: 32 # Only used when generating using HF backend
+ generation_batch_size: 32 # Ignored when using vLLM backend
Also applies to: 60-76
27-35: Normalize boolean casing for consistency. Unify True/False to the repo-preferred casing.
- cpu_offload: False
+ cpu_offload: false
- dynamic_batching:
-   enabled: False
+ dynamic_batching:
+   enabled: false
- enforce_eager: False
+ enforce_eager: false
Also applies to: 36-38, 69-75
59-59: Remove trailing spaces and add a newline at EOF (around "flush_interval: 10"). Yamllint warnings.
Also applies to: 97-98
nemo_rl/distributed/ray_actor_environment_registry.py (1)
39-42: ReplayBuffer likely doesn't need the vLLM environment. Unless it imports vLLM symbols directly, mapping ReplayBuffer to PY_EXECUTABLES.VLLM increases dependency surface and startup time. Prefer PY_EXECUTABLES.BASE or SYSTEM. Keep AsyncTrajectoryCollector under VLLM if it handles vLLM exceptions.
- # ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker
- "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.VLLM,
+ # ReplayBuffer is transport-only; avoid vLLM dependency bloat
+ "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.BASE,
If exceptions from vLLM are serialized through Ray and require import on the receiver, keep it as-is; otherwise prefer BASE.
examples/run_grpo_math.py (2)
281-295: Pass kwargs to grpo_train for signature stability. The synchronous path uses positional args; safer to pass by name to avoid breakage if the signature evolves.
- grpo_train(
-     policy,
-     policy_generation,
-     dataloader,
-     val_dataloader,
-     tokenizer,
-     loss_fn,
-     task_to_env,
-     val_task_to_env,
-     logger,
-     checkpointer,
-     grpo_state,
-     master_config,
- )
+ grpo_train(
+     policy=policy,
+     policy_generation=policy_generation,
+     dataloader=dataloader,
+     val_dataloader=val_dataloader,
+     tokenizer=tokenizer,
+     loss_fn=loss_fn,
+     task_to_env=task_to_env,
+     val_task_to_env=val_task_to_env,
+     logger=logger,
+     checkpointer=checkpointer,
+     grpo_save_state=grpo_state,
+     master_config=master_config,
+ )
225-233: Guard for generation config presence is good; consider asserting IS correction for async. async_grpo_train asserts importance-sampling correction; you could proactively warn in the runner if async_grpo.enabled and the loss config disables it.
nemo_rl/algorithms/grpo.py (8)
1160-1169: Unused variable and redundant print. train_gbs is never used; drop it. Also the duplicate "num_generations_per_prompt/samples_per_prompt_group" lines say the same thing.
- train_gbs = master_config["policy"]["train_global_batch_size"]
@@
- print(f" - train_global_batch_size: {train_gbs}")
1234-1234: Remove unused assignment. The variable is never used.
- collection_task = trajectory_collector.start_collection.remote(dataloader)
+ trajectory_collector.start_collection.remote(dataloader)
1307-1316: wait_iterations never increments. If you keep the debug loop, increment or drop the counter.
  wait_iterations = 0
@@
- # wait_iterations += 1
+ wait_iterations += 1
1432-1436: Assertion message length. Minor: prefer a concise message or a custom exception to satisfy TRY003.
- raise AssertionError(
-     f"Configuration error: (num_prompts_per_step * num_generations_per_prompt) = {expected_batch_size} must be divisible by data_parallel size {dp_size}."
- )
+ raise AssertionError(
+     f"Train batch ({expected_batch_size}) must be divisible by DP size ({dp_size})."
+ )
1495-1511: Rename unused loop var. j is unused.
- for j, message in enumerate(message_log):
+ for _j, message in enumerate(message_log):
1606-1611: Redundant import and GPU mem cleanup. gc is already imported at the top of the file. Keep the cleanup, drop the local import.
- import gc
-
  gc.collect()
  torch.cuda.empty_cache()
1635-1640: warnings.warn without stacklevel. Add stacklevel=2 for actionable locations.
- warnings.warn(
+ warnings.warn(
      f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
      "Saving most recent k checkpoints instead."
- )
+ , stacklevel=2)
1441-1449: Prompt-only advantages in async path vs. full-input in sync path. Async computes baselines from prompt-only tokens (good). The sync path still uses all tokens; this may cause a behavior divergence. Unify both paths to a prompt-only baseline (or document why they differ).
Also applies to: 1501-1507
nemo_rl/algorithms/async_utils.py (4)
424-429: Broad exception handling. Catching bare Exception repeatedly obscures failures. Consider narrowing (ValueError/RuntimeError) or re-raising after logging. If you must keep broad catches in actors, at least log the full stack and include target/prompt ids for triage (you already print tracebacks).
Also applies to: 489-494, 631-636, 636-641, 657-661
258-266: Unused _pg_lock. Not used; remove to reduce noise.
- self._pg_lock: _threading.Lock = _threading.Lock()
95-99: get_existing_target_weights is currently unused. Keep it if you plan to expose telemetry; otherwise remove it.
134-146: Raising on "old trajectories" may crash long runs. Turning this into a warning + purge is gentler, now that per-target quotas prevent accumulation.
- if old_trajectories:
-     raise ValueError(
-         f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}"
-     )
+ if old_trajectories:
+     print(f"⚠️ Dropping {len(old_trajectories)} old trajectories (< {min_valid_version})")
+     # optional: actually purge them here
📜 Review details
📒 Files selected for processing (6)
- examples/configs/async_grpo_math_1B.yaml (1 hunks)
- examples/configs/async_grpo_math_8B.yaml (1 hunks)
- examples/run_grpo_math.py (1 hunks)
- nemo_rl/algorithms/async_utils.py (1 hunks)
- nemo_rl/algorithms/grpo.py (4 hunks)
- nemo_rl/distributed/ray_actor_environment_registry.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
- nemo_rl/distributed/virtual_cluster.py (1): PY_EXECUTABLES (42-58)
nemo_rl/algorithms/async_utils.py (6)
- nemo_rl/algorithms/grpo.py (1): MasterConfig (121-129)
- nemo_rl/data/interfaces.py (1): DatumSpec (32-40)
- nemo_rl/distributed/batched_data_dict.py (2): BatchedDataDict (75-839), repeat_interleave (703-724)
- nemo_rl/environments/interfaces.py (1): EnvironmentInterface (52-88)
- nemo_rl/experience/rollouts.py (1): run_async_multi_turn_rollout (751-895)
- nemo_rl/models/generation/interfaces.py (1): GenerationInterface (208-242)
examples/run_grpo_math.py (1)
- nemo_rl/algorithms/grpo.py (2): async_grpo_train (1090-1738), grpo_train (513-972)
nemo_rl/algorithms/grpo.py (8)
- nemo_rl/utils/timer.py (4): time (110-123), Timer (22-248), get_timing_metrics (196-233), reset (235-248)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): get_actor_python_env (47-62)
- nemo_rl/utils/venvs.py (1): create_local_venv_on_each_node (152-189)
- nemo_rl/utils/logger.py (2): Logger (710-933), log_batched_dict_as_jsonl (804-828)
- nemo_rl/algorithms/async_utils.py (12): AsyncTrajectoryCollector (239-660), ReplayBuffer (36-235), start_collection (362-373), set_weight_version (327-336), pause (498-501), resume (503-506), size (225-228), sample (100-223), get_debug_info (82-89), prepare_for_refit (508-526), resume_after_refit (528-531), get_dataloader_state (555-559)
- nemo_rl/distributed/batched_data_dict.py (4): size (793-802), BatchedDataDict (75-839), from_batches (102-151), to (804-811)
- nemo_rl/data/llm_message_utils.py (1): batched_message_log_to_flat_message (233-390)
- nemo_rl/algorithms/utils.py (1): calculate_baseline_and_std_per_prompt (47-116)
🪛 YAMLlint (1.37.1)
examples/configs/async_grpo_math_1B.yaml
- [error] 82-82: trailing spaces (trailing-spaces)
- [error] 108-108: trailing spaces (trailing-spaces)
- [error] 109-109: no new line character at the end of file (new-line-at-end-of-file)
- [error] 109-109: trailing spaces (trailing-spaces)
examples/configs/async_grpo_math_8B.yaml
- [error] 59-59: trailing spaces (trailing-spaces)
- [error] 97-97: trailing spaces (trailing-spaces)
- [error] 98-98: no new line character at the end of file (new-line-at-end-of-file)
- [error] 98-98: trailing spaces (trailing-spaces)
🪛 Ruff (0.12.2)
nemo_rl/algorithms/async_utils.py
- 143-145: Avoid specifying long messages outside the exception class (TRY003)
- 358-358: Consider moving this statement to an else block (TRY300)
- 359-359: Do not catch blind exception: Exception (BLE001)
- 424-424: Do not catch blind exception: Exception (BLE001)
- 489-489: Do not catch blind exception: Exception (BLE001)
- 631-631: Do not catch blind exception: Exception (BLE001)
- 636-636: Do not catch blind exception: Exception (BLE001)
- 657-657: Do not catch blind exception: Exception (BLE001)
nemo_rl/algorithms/grpo.py
- 1234-1234: Local variable collection_task is assigned to but never used; remove the assignment (F841)
- 1252-1252: Do not catch blind exception: Exception (BLE001)
- 1263-1263: Do not catch blind exception: Exception (BLE001)
- 1291-1291: Do not catch blind exception: Exception (BLE001)
- 1434-1436: Avoid specifying long messages outside the exception class (TRY003)
- 1495-1495: Loop control variable j not used within loop body; rename unused j to _j (B007)
- 1635-1635: No explicit stacklevel keyword argument found; set stacklevel=2 (B028)
- 1730-1730: Do not catch blind exception: Exception (BLE001)
- 1735-1735: Do not catch blind exception: Exception (BLE001)
🔇 Additional comments (13)
examples/configs/async_grpo_math_1B.yaml (1)
2-2: Defaults resolution path verified: load_config_with_inheritance wraps a string defaults value into a list before merging parent configs, so defaults: "grpo_math_1B.yaml" is correctly applied.
examples/configs/async_grpo_math_8B.yaml (2)
2-2: Verify defaults resolution path (same as the 1B config). Ensure defaults: "grpo_math_1B.yaml" is actually merged by your loader. Use the same script provided in the 1B config comment.
18-23: Sanity-check effective batch vs hardware. train_global_batch_size: 512, train_micro_batch_size: 1, gpus_per_node: 8, num_nodes: 1 implies very high accumulation. Confirm divisibility and memory headroom at 4096 sequence length; otherwise consider reducing the global batch or enabling dynamic batching.
Also applies to: 81-84
examples/run_grpo_math.py (1)
255-277: LGTM: clean async toggle and call surface. Conditional path, clear print, and explicit kwargs into async_grpo_train look good.
nemo_rl/algorithms/grpo.py (5)
1151-1153: Colocated inference unsupported in async path — good to assert. Clear early failure mode. Consider improving the message with a hint to set policy.generation.colocated.enabled=false.
1273-1300: Pause/resume around initial validation — good guardrails. Pausing collection avoids pressure during validation. Nice.
1666-1667: Verify offload_after_refit at end of checkpoint. This method is intended for refit. Using it post-checkpoint may evict state unexpectedly if the next step runs immediately. Would you confirm it's safe here for Megatron/vLLM? If not, consider removing it in the checkpoint path.
1083-1086: GPU memory cleanup after validation — looks good. Makes OOMs less likely during long async runs.
1424-1436: No action required: BatchedDataDict.size is a @property and calls to .size (no parentheses) are correct.
nemo_rl/algorithms/async_utils.py (4)
282-289: Target selection and pause logic — LGTM. Deterministic future targeting with a bounded age window and backpressure via last-target is sound given the per-target quota fix.
Also applies to: 306-325, 338-361, 392-414
438-471: num_prompts = batch.size — property vs method. Same note as in the main file: ensure .size is a property, or call batch.size(). See the script in the grpo.py comment.
561-566: Thread set maintenance — good. Cleanup avoids unbounded growth of _inflight_threads.
367-374: Daemon collection thread and clean return log — LGTM. Actor lifecycle is clean; the driver kills actors in finally.
@ray.remote
class ReplayBuffer:
    """Replay buffer storing per-prompt groups.

    A single entry corresponds to 1 prompt repeated by
    grpo.num_generations_per_prompt (required to compute per-prompt advantages).
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.trajectories = []
        # If trajectory_version is 1 and target_weight_version is 4 it means that weight version 1 was used for generating a trajectory and this trajectory will be used for training when weight version is 4.
        self.trajectory_versions = []  # it is the weight-version used for generation of a trajectory
        self.target_weight_versions = []  # it is the weight-version of the trainer where this trajectory will be used.

        self.last_target_weight_already_generated = -1
        self._lock = _threading.Lock()
🛠️ Refactor suggestion
Add per-target quotas and late-advance semantics to ReplayBuffer. Enforce exactly groups_per_target groups per target, and only mark last_target_weight_already_generated when the quota is met. This prevents overfill/underfill and avoids stale leftovers that later trigger the "old trajectory" error.
Apply:
- class ReplayBuffer:
+ class ReplayBuffer:
@@
- def __init__(self, max_size: int):
+ def __init__(self, max_size: int, groups_per_target: int):
self.max_size = max_size
+ self.groups_per_target = groups_per_target
self.trajectories = []
@@
self.last_target_weight_already_generated = -1
self._lock = _threading.Lock()
+ self._target_counts: dict[int, int] = {}
@@
- ) -> str:
+ ) -> str:
@@
- self.trajectories.append(trajectory)
- self.trajectory_versions.append(weight_version)
- self.target_weight_versions.append(target_weight_version)
- self.last_target_weight_already_generated = max(
- self.last_target_weight_already_generated, target_weight_version
- )
+ current = self._target_counts.get(target_weight_version, 0)
+ if current >= self.groups_per_target:
+ return "redundant"
+ self.trajectories.append(trajectory)
+ self.trajectory_versions.append(weight_version)
+ self.target_weight_versions.append(target_weight_version)
+ self._target_counts[target_weight_version] = current + 1
+ # Advance only when quota is satisfied
+ if self._target_counts[target_weight_version] >= self.groups_per_target:
+ self.last_target_weight_already_generated = max(
+ self.last_target_weight_already_generated, target_weight_version
+ )
@@
- sampled_items = [self.trajectories[i] for i in selected]
+ sampled_items = [self.trajectories[i] for i in selected]
@@
for idx in sorted(selected, reverse=True):
- self.trajectory_versions.pop(idx)
- self.target_weight_versions.pop(idx)
- self.trajectories.pop(idx)
+ tw = self.target_weight_versions.pop(idx)
+ self.trajectory_versions.pop(idx)
+ self.trajectories.pop(idx)
+ self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+ if self._target_counts[tw] <= 0:
+ del self._target_counts[tw]
@@
with self._lock:
return len(self.trajectories)
Additionally, return "complete" when the quota is met:
- return "success"
+ return "complete" if self._target_counts[target_weight_version] >= self.groups_per_target else "success"
Also applies to: 82-90, 95-99, 100-223, 225-236
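For intuition, a self-contained toy version of the push semantics proposed above (not the PR's class; locking, the size cap, and the stored trajectory payloads are omitted):

```python
class MiniReplayBuffer:
    """Toy model of the per-target-quota idea suggested above."""

    def __init__(self, groups_per_target: int):
        self.groups_per_target = groups_per_target
        self.last_target_weight_already_generated = -1
        self._target_counts: dict[int, int] = {}

    def push(self, target_weight_version: int) -> str:
        count = self._target_counts.get(target_weight_version, 0)
        if count >= self.groups_per_target:
            return "redundant"  # quota already met; drop the extra group
        self._target_counts[target_weight_version] = count + 1
        if self._target_counts[target_weight_version] >= self.groups_per_target:
            # Advance only once every group for this target has arrived.
            self.last_target_weight_already_generated = max(
                self.last_target_weight_already_generated, target_weight_version
            )
            return "complete"  # caller may now release its reservation
        return "success"  # buffered, but the target is still open


buf = MiniReplayBuffer(groups_per_target=2)
assert buf.push(4) == "success"    # first group for target 4
assert buf.push(4) == "complete"   # quota met -> safe to release the reservation
assert buf.push(4) == "redundant"  # over-quota pushes are rejected
assert buf.last_target_weight_already_generated == 4
```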
f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})" | ||
) | ||
|
||
# Release reservation when FIRST prompt group for this target is successfully buffered | ||
if prompt_idx == 0: | ||
with self._generation_check_lock: | ||
if target_weight_version in self._generating_targets: | ||
self._generating_targets.discard( | ||
target_weight_version | ||
) | ||
print( | ||
f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)" | ||
) | ||
break |
🛠️ Refactor suggestion
Reservation released on first push — can stall training on partial failures. Releasing _generating_targets at prompt_idx == 0 advances scheduling even if not all groups are buffered. If some threads fail, training will wait forever for the missing groups, and the collector won't regenerate them.
Fix by making ReplayBuffer enforce per-target counts and only "advance" a target after groups_per_target pushes. Have push_with_wait_signal return "complete" when it reaches the quota; release the reservation only on "complete".
- if status == "success":
+ if status in ("success", "complete"):
print(
f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
)
-
- # Release reservation when FIRST prompt group for this target is successfully buffered
- if prompt_idx == 0:
- with self._generation_check_lock:
- if target_weight_version in self._generating_targets:
- self._generating_targets.discard(
- target_weight_version
- )
- print(
- f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)"
- )
+ if status == "complete":
+ with self._generation_check_lock:
+ if target_weight_version in self._generating_targets:
+ self._generating_targets.discard(target_weight_version)
+ print(f"🧹 Released reservation for target weight {target_weight_version} (all prompts buffered)")
break
See ReplayBuffer changes below.
Also applies to: 642-649
🤖 Prompt for AI Agents
In nemo_rl/algorithms/async_utils.py around lines 610-623 (and similarly
642-649), the code releases the _generating_targets reservation as soon as
prompt_idx == 0 which can advance scheduling prematurely and stall training if
other groups never buffer; change the logic so ReplayBuffer enforces per-target
group counts and push_with_wait_signal returns "complete" only when
groups_per_target pushes for that target have been received, then only discard
target_weight_version from _generating_targets when push_with_wait_signal
returns "complete" (i.e., replace the prompt_idx==0 check with checking the
return value and release reservation only on "complete").
nemo_rl/algorithms/grpo.py (Outdated)
step = grpo_save_state["step"]
weight_version = step  # Tracks refitted weight versions
consumed_samples = grpo_save_state["consumed_samples"]
val_period = master_config["grpo"]["val_period"]
Fix: inconsistent training state keys; grpo_save_state['step'] will KeyError. The save-state schema uses current_step/total_steps. Unify reads/writes to avoid crashes and incompatible checkpoints.
Apply:
- step = grpo_save_state["step"]
+ step = grpo_save_state.get("current_step", 0)
@@
- grpo_save_state["step"] = step + 1
+ grpo_save_state["current_step"] = step + 1
Also applies to: 1623-1629
# Calculate optimal buffer size based on generation limits to prevent length bias
# Each weight version generates exactly num_prompts_per_step trajectories
# With max_age_steps, we keep trajectories from multiple weight versions
num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"]
late_arrival_slack = 2
optimal_buffer_size = (
    num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack
)

replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
    max_size=optimal_buffer_size
)
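For a quick sanity check of the sizing formula above, plugging in the values from this PR's async 1B config (num_prompts_per_step: 64, max_trajectory_age_steps: 1) and the hard-coded slack of 2:

```python
num_prompts_per_step = 64      # grpo.num_prompts_per_step in the async 1B config
max_trajectory_age_steps = 1   # grpo.async_grpo.max_trajectory_age_steps
late_arrival_slack = 2         # hard-coded in the snippet above

optimal_buffer_size = num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack
print(optimal_buffer_size)     # 128 per-prompt groups
```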
🛠️ Refactor suggestion
ReplayBuffer ctor lacks per-target capacity; risk of over/under-fill and staleness
Without per-target quotas, a partial or duplicated target can permanently stall sampling or create stale leftovers. Pass groups-per-target and enforce it in the actor (see async_utils diff).
Apply here after updating ReplayBuffer (see other file):
- replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
- max_size=optimal_buffer_size
- )
+ replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+ max_size=optimal_buffer_size,
+ groups_per_target=num_prompts_per_step,
+ )
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1190 to 1202, the ReplayBuffer is
instantiated without a per-target quota which can lead to over/under-fill and
stale entries; add a groups_per_target parameter to the actor creation call
using the GRPO config (e.g. groups_per_target =
master_config["grpo"]["groups_per_target"] or computed from the GRPO settings)
and pass it into ReplayBuffer.options(...).remote(max_size=optimal_buffer_size,
groups_per_target=groups_per_target) so the actor can enforce per-target
capacity (the actor-side enforcement is handled in the async_utils update).
try:
    # Run rollout for this prompt group
    # Async engine supports concurrent generation; avoid locking
    final_batch, rollout_metrics = run_async_multi_turn_rollout(
wdyt about timing this block and aggregating? I'm wondering if it's possible to measure the "generation bubble": with enough generation servers you could have lots of idle time, which might signal that you have too many generation servers or too low a max_trajectory_age_steps.
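A minimal sketch of the kind of timing/aggregation this comment suggests; the wrapper, the per-version dict, and where the numbers would be flushed to the logger are all assumptions, not part of this PR:

```python
import time
from collections import defaultdict

# Per-weight-version rollout durations, collected as a side effect.
rollout_seconds: dict[int, list[float]] = defaultdict(list)


def timed_rollout(run_rollout, weight_version, *args, **kwargs):
    """Wrap a rollout call and record its wall-clock duration for this weight version."""
    start = time.perf_counter()
    result = run_rollout(*args, **kwargs)
    rollout_seconds[weight_version].append(time.perf_counter() - start)
    return result


def generation_bubble_report(weight_version: int) -> str:
    """Summarize busy time; comparing it to wall-clock time between refits hints at idle generation servers."""
    samples = rollout_seconds[weight_version]
    return f"version {weight_version}: {len(samples)} rollouts, {sum(samples):.1f}s busy"
```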
# Async-specific settings
async_grpo:
  enabled: true # Enable async training
  max_trajectory_age_steps: 1 # Allow trajectories from the last 1 training steps

grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
was there an advantage to maintaining another config? wdyt about merging this config with grpo_math_1B.yaml and moving the async args under the algo root (default to false)? then putting these two recipes into examples/configs/recipes/llm/ & tests/test_suites/llm/ so we can track these nightly?
-# Async-specific settings
-async_grpo:
-  enabled: true # Enable async training
-  max_trajectory_age_steps: 1 # Allow trajectories from the last 1 training steps
-
-grpo:
-  num_prompts_per_step: 64
-  num_generations_per_prompt: 32
+grpo:
+  num_prompts_per_step: 64
+  num_generations_per_prompt: 32
+  # Async-specific settings
+  async_cfg:
+    enabled: true # Enable async training
+    max_trajectory_age_steps: 1 # Allow trajectories from the last 1 training steps
Thanks, moved under grpo and removed async config examples.
nemo_rl/algorithms/grpo.py (Outdated)
# Wait for initial buffer fill
print(
    f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..."
f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..." | |
f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed=})..." |
to avoid being conflated with minutes
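For reference, the `=` specifier used in the suggestion is standard f-string debugging syntax (Python 3.8+); it prints the expression text alongside its value:

```python
min_trajectories_needed = 64
print(f"waiting ({min_trajectories_needed=})")
# -> waiting (min_trajectories_needed=64)
```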
nemo_rl/algorithms/grpo.py (Outdated)
buffer_size_current = ray.get(replay_buffer.size.remote())

print(
    f" Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}"
f" Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}" | |
f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}" |
nemo_rl/algorithms/grpo.py (Outdated)
if buffer_size_current >= min_trajectories_needed:
    break

# wait_iterations += 1
I'm guessing this comment is intentional. Do you mind leaving a comment if so?
f" Trajectory versions in buffer: {buffer_debug['trajectory_versions']}" | ||
) | ||
|
||
time.sleep(0.5) |
is it possible the prints above this could get spammy if it is stuck in this loop for a while?
It does but in practice I found these to be useful and not too spammy. We can revisit later if we want to remove these.
nemo_rl/algorithms/async_utils.py (Outdated)
total_trajectories = len(self.trajectories)
print("🔍 ReplayBuffer sampling debug:")
print(
    f" current_weight_version={current_weight_version}, max_age_steps={max_age_steps}"
f" current_weight_version={current_weight_version}, max_age_steps={max_age_steps}" | |
f" {current_weight_version=}, {max_age_steps=}" |
nemo_rl/algorithms/async_utils.py (Outdated)
# Compute minimum valid version based on age window
# max_age_steps=1 means trajectories from the last 1 step are valid
min_valid_version = max(0, current_weight_version - max_age_steps)
print(f" min_valid_version={min_valid_version}")
print(f" min_valid_version={min_valid_version}") | |
print(f" {min_valid_version=}") |
save a few characters :)
nemo_rl/algorithms/async_utils.py (Outdated)
self._generating_targets: set[int] = set()

def _calculate_target_weights(self, generation_weight_version: int) -> list[int]:
    """Calculate target weight versions for given generation weight version."""
Feel free to workshop this, but from our offline discussion:
"""Calculate target weight versions for given generation weight version.""" | |
"""Calculate target weight versions for given generation weight version. | |
The list of versions returned enumerate the possible version a generation | |
server can target. These versions are looped over to see what training | |
step they can target. If all target versions are exhausted, this generation | |
server will remain idle until the next weight update. | |
Example: | |
generation_weight_version = 10 | |
max_trajectory_age_steps = 4 | |
Returns: | |
[11, 12, 13, 14] # Meaning this generation server can create trajectories for training step 11, 12, 13, 14 | |
""" |
just an FYI, the sync grpo loop has multi-epoch support. keep that for a future PR?
Yes, will add in a separate PR.
# Import async utilities only when needed
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer

timer = Timer()
For the next PR, don't want to add more core changes to algo before this version is checked in.
@parthchadha were you planning on adding:
- a doc page
- unit tests
- function tests (ex: https://github.com/NVIDIA-NeMo/RL/blob/main/tests/functional/L1_Functional_Tests_GPU.sh)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/algorithms/grpo.py (1)
986-992: Fix config key: use grpo.val_period in validate(). master_config["dpo"]["val_period"] is wrong for GRPO and can assert incorrectly.
- if val_dataloader is None:
-     assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+ if val_dataloader is None:
+     assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
          "val_dataloader is None, so dpo.val_period must be 0"
      )
♻️ Duplicate comments (4)
nemo_rl/algorithms/grpo.py (4)
1195-1206: Pass per-target quota to ReplayBuffer to avoid stalls and staleness. Add groups_per_target=num_prompts_per_step so the actor can enforce per-target capacity.
- replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
-     max_size=optimal_buffer_size
- )
+ replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+     max_size=optimal_buffer_size,
+     groups_per_target=num_prompts_per_step,
+ )
1308-1321: Clarify wait message and increment counter; throttle logs.
- "min=" can be misread as minutes.
- wait_iterations never increments.
- Prevent spam by logging every N iterations.
- print(
-     f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..."
- )
+ print(
+     f"⏳ Waiting for replay buffer to have sufficient trajectories (min_trajectories_needed={min_trajectories_needed})..."
+ )
  wait_iterations = 0
  while True:
      buffer_size_current = ray.get(replay_buffer.size.remote())
-
-     print(
-         f" Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}"
-     )
+     if wait_iterations % 10 == 0:
+         print(
+             f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}"
+         )
+     wait_iterations += 1
1138-1145: Add timeout-based checkpointing like the sync loop. Prevents losing progress when a save deadline is near.
  timer = Timer()
+ timeout = TimeoutChecker(
+     timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+     fit_last_save_time=True,
+ )
+ timeout.start_iterations()
@@
- # Checkpointing (same as sync version)
- consumed_samples += master_config["grpo"]["num_prompts_per_step"]
- if master_config["checkpointing"]["enabled"] and (
-     is_last_step
-     or (step + 1) % master_config["checkpointing"]["save_period"] == 0
- ):
+ # Checkpointing (+ timeout)
+ consumed_samples += master_config["grpo"]["num_prompts_per_step"]
+ timeout.mark_iteration()
+ should_save_by_timeout = timeout.check_save()
+ if master_config["checkpointing"]["enabled"] and (
+     is_last_step
+     or (step + 1) % master_config["checkpointing"]["save_period"] == 0
+     or should_save_by_timeout
+ ):
@@
- timer.reset()
- step += 1
+ timer.reset()
+ step += 1
+ if 'should_save_by_timeout' in locals() and should_save_by_timeout:
+     break
Also applies to: 1621-1626, 1727-1730
1148-1152: Unify save-state schema: current_step vs step (prevents KeyError and incompatible checkpoints). GRPOSaveState defines current_step/total_steps. Accessing/setting step will break resume.
- step = grpo_save_state["step"]
+ step = grpo_save_state.get("current_step", 0)
- grpo_save_state["step"] = step + 1
+ grpo_save_state["current_step"] = step + 1
+ grpo_save_state["total_steps"] = grpo_save_state.get("total_steps", 0) + 1
Also applies to: 1628-1634
🧹 Nitpick comments (6)
nemo_rl/algorithms/grpo.py (6)
687-697: Avoid KeyError for optional overlong_filtering. The field is NotRequired; default to False.
- use_overlong_filtering = master_config["grpo"]["overlong_filtering"]
+ use_overlong_filtering = master_config["grpo"].get("overlong_filtering", False)
1238-1240: Remove unused variable (Ruff F841) or make intent explicit. The assignment isn't used; keep a throwaway to signal intentional fire-and-forget.
- collection_task = trajectory_collector.start_collection.remote(dataloader)
+ _ = trajectory_collector.start_collection.remote(dataloader)
1499-1516: Rename unused loop var j (Ruff B007). No functional change.
- for i, message_log in enumerate(repeated_batch["message_log"]):
-     for j, message in enumerate(message_log):
+ for i, message_log in enumerate(repeated_batch["message_log"]):
+     for _j, message in enumerate(message_log):
1407-1444: Assertion message is long (TRY003); optional tidy. Minor: shorten or rely on the default message to satisfy linters.
1281-1305: Broad exception catches; prefer narrower or re-raise with context. Catch specific Ray/Runtime errors; avoid swallowing unrelated failures.
1088-1091: GPU memory cleanup: optionally gate on CUDA availability. Harmless as-is; consider guarding to avoid surprises on CPU-only runs.
- gc.collect()
- torch.cuda.empty_cache()
+ gc.collect()
+ if torch.cuda.is_available():
+     torch.cuda.empty_cache()
📜 Review details
📒 Files selected for processing (1)
- nemo_rl/algorithms/grpo.py (4 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
nemo_rl/algorithms/grpo.py
- 1239-1239: Local variable collection_task is assigned to but never used; remove the assignment (F841)
- 1257-1257: Do not catch blind exception: Exception (BLE001)
- 1268-1268: Do not catch blind exception: Exception (BLE001)
- 1296-1296: Do not catch blind exception: Exception (BLE001)
- 1439-1441: Avoid specifying long messages outside the exception class (TRY003)
- 1500-1500: Loop control variable j not used within loop body; rename unused j to _j (B007)
- 1640-1640: No explicit stacklevel keyword argument found; set stacklevel=2 (B028)
- 1735-1735: Do not catch blind exception: Exception (BLE001)
- 1740-1740: Do not catch blind exception: Exception (BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (1)
nemo_rl/algorithms/grpo.py (1)
1670-1672: Guard offload_after_refit for non-colocated setups. This method's semantics are tied to colocated refits.
- policy.offload_after_refit()
+ if colocated_inference:
+     policy.offload_after_refit()
nemo_rl/algorithms/grpo.py (Outdated)
assert _should_use_async_rollouts(master_config), (
    "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. "
    "Set policy.generation.vllm_cfg.async_engine to true in your config."
)
assert loss_fn["use_importance_sampling_correction"] is True, (
    "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
)
Use the loss-fn config for the IS correction check (current code will TypeError). loss_fn is a callable/object, not a dict. Check master_config['loss_fn'] (or an attribute on loss_fn if exposed).
- assert loss_fn["use_importance_sampling_correction"] is True, (
+ assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, (
"Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
)
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1128 to 1134, the code asserts
loss_fn["use_importance_sampling_correction"] but loss_fn is a callable/object
which causes a TypeError; update the assertion to read the configuration from
master_config['loss_fn'] (e.g.
master_config['loss_fn']['use_importance_sampling_correction']) or, if the
loss_fn object exposes an attribute, use getattr(loss_fn,
'use_importance_sampling_correction', False) to check truthiness, and keep the
same descriptive error message if the flag is not enabled.
# Create training data
train_data = BatchedDataDict[ClippedPGLossDataDict](
    {
        "input_ids": flat_messages["token_ids"],
        "input_lengths": input_lengths,
        "advantages": flat_messages["advantages"],
        "generation_logprobs": flat_messages["generation_logprobs"],
        "token_mask": flat_messages["token_loss_mask"],
        "sample_mask": repeated_batch["loss_multiplier"],
    }
)
train_data.to("cpu")
Parity with sync path: include multimodal fields in train_data. The async path omits get_multimodal_dict, which breaks multimodal inputs.
train_data = BatchedDataDict[ClippedPGLossDataDict](
{
"input_ids": flat_messages["token_ids"],
"input_lengths": input_lengths,
"advantages": flat_messages["advantages"],
"generation_logprobs": flat_messages["generation_logprobs"],
"token_mask": flat_messages["token_loss_mask"],
"sample_mask": repeated_batch["loss_multiplier"],
}
)
+ # Keep multimodal parity with sync flow
+ train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
train_data.to("cpu")
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1526 to 1537, the async training path
builds train_data without the multimodal fields, breaking multimodal inputs;
update the dict construction to include the result of
get_multimodal_dict(flat_messages) (merged into the mapping passed to
BatchedDataDict[ClippedPGLossDataDict]) so the same multimodal keys present on
the sync path are included, and ensure any multimodal tensors are also moved to
CPU via train_data.to("cpu") as done for the other fields.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/algorithms/grpo.py (1)
986-992: Bug: references dpo in GRPO validate path. This will KeyError if val_dataloader is None. Use grpo.
Apply:
- if val_dataloader is None:
-     assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+ if val_dataloader is None:
+     assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
          "val_dataloader is None, so dpo.val_period must be 0"
      )
♻️ Duplicate comments (5)
nemo_rl/algorithms/async_utils.py (3)
43-51
: ReplayBuffer advances “last_target” too early — causes partial targets and stallsAdvancing last_target_weight_already_generated on the first push lets the collector skip remaining groups for that target. Enforce a per‑target quota and advance only after the quota is met.
Apply:
-class ReplayBuffer: +class ReplayBuffer: @@ - def __init__(self, max_size: int): + def __init__(self, max_size: int, groups_per_target: int): self.max_size = max_size + self.groups_per_target = groups_per_target self.trajectories = [] @@ self.last_target_weight_already_generated = -1 self._lock = _threading.Lock() + self._target_counts: dict[int, int] = {}
53-81
: Add per‑target quota and late‑advance semantics; return “complete” when quota metPrevents over/under‑fill and aligns reservation release with full completion for a target.
Apply:
def push_with_wait_signal( @@ - with self._lock: + with self._lock: if len(self.trajectories) >= self.max_size: return "full" print("🔍 ReplayBuffer.push_with_wait_signal: Adding trajectory") - self.trajectories.append(trajectory) - self.trajectory_versions.append(weight_version) - self.target_weight_versions.append(target_weight_version) - self.last_target_weight_already_generated = max( - self.last_target_weight_already_generated, target_weight_version - ) + current = self._target_counts.get(target_weight_version, 0) + if current >= self.groups_per_target: + return "redundant" + self.trajectories.append(trajectory) + self.trajectory_versions.append(weight_version) + self.target_weight_versions.append(target_weight_version) + self._target_counts[target_weight_version] = current + 1 + # Advance only when quota is satisfied + if self._target_counts[target_weight_version] >= self.groups_per_target: + self.last_target_weight_already_generated = max( + self.last_target_weight_already_generated, target_weight_version + ) print( f"ReplayBuffer state: {len(self.trajectories)} groups, versions={self.trajectory_versions}, targets={self.target_weight_versions}, last_target_weight_already_generated={self.last_target_weight_already_generated}" ) - return "success" + return ( + "complete" + if self._target_counts[target_weight_version] >= self.groups_per_target + else "success" + )
609-643
: Release reservations only when a target is fully bufferedCurrent code releases on the first prompt, risking permanent underfill if others fail.
Apply:
-                    if status == "success":
+                    if status in ("success", "complete"):
                         print(
                             f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
                         )
-
-                        # Release reservation when FIRST prompt group for this target is successfully buffered
-                        if prompt_idx == 0:
-                            with self._generation_check_lock:
-                                if target_weight_version in self._generating_targets:
-                                    self._generating_targets.discard(
-                                        target_weight_version
-                                    )
-                                    print(
-                                        f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)"
-                                    )
+                        if status == "complete":
+                            with self._generation_check_lock:
+                                if target_weight_version in self._generating_targets:
+                                    self._generating_targets.discard(target_weight_version)
+                                    print(f"🧹 Released reservation for target weight {target_weight_version} (all prompts buffered)")
                         break
nemo_rl/algorithms/grpo.py (2)
1204-1207
: Pass per‑target quota into ReplayBuffer. Required for the quota/late‑advance semantics above.
Apply:
-    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
-        max_size=optimal_buffer_size
-    )
+    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+        max_size=optimal_buffer_size,
+        groups_per_target=num_prompts_per_step,
+    )
1507-1519
: Async path misses multimodal fields present in sync path. Parity fix to keep multimodal inputs working.
Apply:
         train_data = BatchedDataDict[ClippedPGLossDataDict](
             {
                 "input_ids": flat_messages["token_ids"],
                 "input_lengths": input_lengths,
                 "advantages": flat_messages["advantages"],
                 "generation_logprobs": flat_messages["generation_logprobs"],
                 "token_mask": flat_messages["token_loss_mask"],
                 "sample_mask": repeated_batch["loss_multiplier"],
             }
         )
+        # Keep multimodal parity with sync flow
+        train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
         train_data.to("cpu")
🧹 Nitpick comments (8)
examples/configs/grpo_math_8B.yaml (1)
7-9
: Clarify async prerequisites and tune default age. Async requires both vLLM async_engine=true and IS correction enabled; consider documenting this inline and defaulting age to >=2 if your convergence plots support it.
Apply:
 async_grpo:
-  enabled: false
-  max_trajectory_age_steps: 1
+  enabled: false
+  # Requires: policy.generation.vllm_cfg.async_engine=true and loss_fn.use_importance_sampling_correction=true
+  max_trajectory_age_steps: 1 # Increase to 2–4 if targeting higher throughput and your runs remain stable
examples/configs/grpo_math_1B.yaml (1)
28-30
: Auto‑wire IS correction to async toggle. Avoids drift between config and the runtime assert.
Apply:
-  # Async GRPO requires importance sampling correction enabled
-  # Set to true when async_grpo.enabled is true
-  use_importance_sampling_correction: false
+  # Async GRPO requires importance sampling correction enabled
+  # Auto-enable when async_grpo is enabled
+  use_importance_sampling_correction: ${grpo.async_grpo.enabled}
nemo_rl/algorithms/async_utils.py (3)
317-336
: Avoid duplicate targeting; consult existing targets set. If the process restarts or quotas change, use get_existing_target_weights to skip duplicates, not just last_target.
Apply:
-            last_target_weight_already_generated = ray.get(
-                self.replay_buffer.get_last_target_weight_already_generated.remote()
-            )
+            last_target_weight_already_generated, existing_targets = ray.get(
+                [
+                    self.replay_buffer.get_last_target_weight_already_generated.remote(),
+                    self.replay_buffer.get_existing_target_weights.remote(),
+                ]
+            )
@@
-            if (
+            if (
                 target_weight > last_target_weight_already_generated
-                and target_weight not in self._generating_targets
+                and target_weight not in self._generating_targets
+                and target_weight not in existing_targets
             ):
404-415
: Minor: unify target list logging with calculator to avoid off‑by‑one confusion. Use _calculate_target_weights for the log instead of rebuilding with range(max_age).
Apply:
-            target_weights = [
-                self.current_weight_version + i
-                for i in range(max_trajectory_age)
-            ]
+            target_weights = self._calculate_target_weights(self.current_weight_version)
370-371
: Narrow broad excepts and keep messages on the exception. Catching bare Exception obscures actionable errors and trips linters. Catch specific errors or at least log e with repr and re‑raise where appropriate.
Also applies to: 436-436, 501-501, 643-643, 648-648, 669-669
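A minimal sketch of the suggested pattern (the function name, the call it wraps, and the exception types are illustrative, not the actual ones in async_utils.py):

import logging

logger = logging.getLogger(__name__)

def push_with_logged_errors(replay_buffer, trajectory):
    try:
        return replay_buffer.push(trajectory)  # placeholder call for illustration
    except (TimeoutError, ConnectionError) as e:
        # Narrow except: only failures we can reason about, logged with repr
        # and re-raised so the caller can decide how to recover.
        logger.error("Failed to push trajectory: %r", e)
        raise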
nemo_rl/algorithms/grpo.py (3)
1481-1497
: Rename unused loop variable j. Minor lint fix.
Apply:
-        for i, message_log in enumerate(repeated_batch["message_log"]):
-            for j, message in enumerate(message_log):
+        for i, message_log in enumerate(repeated_batch["message_log"]):
+            for _j, message in enumerate(message_log):
And similarly in the sync loop.
Also applies to: 699-715
1621-1624
: warnings.warn: set stacklevel for actionable source lines. Improves navigation from logs.
Apply:
-        warnings.warn(
+        warnings.warn(
             f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
             "Saving most recent k checkpoints instead."
-        )
+            , stacklevel=2)
1652-1653
: Only offload_after_refit for colocated inference. Avoid potential no‑op or unexpected behavior when non‑colocated.
Apply:
-        policy.offload_after_refit()
+        if colocated_inference:
+            policy.offload_after_refit()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
examples/configs/grpo_math_1B.yaml
(2 hunks)examples/configs/grpo_math_8B.yaml
(1 hunks)examples/run_grpo_math.py
(1 hunks)nemo_rl/algorithms/async_utils.py
(1 hunks)nemo_rl/algorithms/grpo.py
(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/run_grpo_math.py
🧰 Additional context used
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (6)
- nemo_rl/utils/timer.py (1): time (110-123)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): get_actor_python_env (47-62)
- nemo_rl/utils/venvs.py (1): create_local_venv_on_each_node (152-189)
- nemo_rl/algorithms/async_utils.py (12): AsyncTrajectoryCollector (237-672), ReplayBuffer (36-233), start_collection (373-384), set_weight_version (338-347), pause (510-513), resume (515-518), size (223-226), sample (100-221), get_debug_info (82-89), prepare_for_refit (520-538), resume_after_refit (540-543), get_dataloader_state (567-571)
- nemo_rl/distributed/batched_data_dict.py (4): size (793-802), BatchedDataDict (75-839), from_batches (102-151), to (804-811)
- nemo_rl/data/llm_message_utils.py (1): batched_message_log_to_flat_message (233-390)
nemo_rl/algorithms/async_utils.py (6)
- nemo_rl/algorithms/grpo.py (1): MasterConfig (121-129)
- nemo_rl/data/interfaces.py (1): DatumSpec (32-40)
- nemo_rl/distributed/batched_data_dict.py (2): BatchedDataDict (75-839), repeat_interleave (703-724)
- nemo_rl/environments/interfaces.py (1): EnvironmentInterface (52-88)
- nemo_rl/experience/rollouts.py (1): run_async_multi_turn_rollout (751-895)
- nemo_rl/models/generation/interfaces.py (1): GenerationInterface (208-242)
🪛 Ruff (0.12.2)
nemo_rl/algorithms/grpo.py
- 1239-1239: Local variable collection_task is assigned to but never used. Remove assignment to unused variable collection_task (F841)
- 1257-1257: Do not catch blind exception: Exception (BLE001)
- 1268-1268: Do not catch blind exception: Exception (BLE001)
- 1296-1296: Do not catch blind exception: Exception (BLE001)
- 1420-1422: Avoid specifying long messages outside the exception class (TRY003)
- 1481-1481: Loop control variable j not used within loop body. Rename unused j to _j (B007)
- 1621-1621: No explicit stacklevel keyword argument found. Set stacklevel=2 (B028)
- 1716-1716: Do not catch blind exception: Exception (BLE001)
- 1721-1721: Do not catch blind exception: Exception (BLE001)
nemo_rl/algorithms/async_utils.py
- 141-143: Avoid specifying long messages outside the exception class (TRY003)
- 369-369: Consider moving this statement to an else block (TRY300)
- 370-370: Do not catch blind exception: Exception (BLE001)
- 436-436: Do not catch blind exception: Exception (BLE001)
- 501-501: Do not catch blind exception: Exception (BLE001)
- 643-643: Do not catch blind exception: Exception (BLE001)
- 648-648: Do not catch blind exception: Exception (BLE001)
- 669-669: Do not catch blind exception: Exception (BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (3)
examples/configs/grpo_math_1B.yaml (1)
16-19
: Async block: add explicit cross‑refs to required flags. A brief note here reduces misconfigurations when toggling async.
[raise_nitpick_issue]
Apply:
 async_grpo:
-  enabled: false # Set to true to enable async training mode
+  enabled: false # Set to true to enable async training mode (requires vLLM async_engine and IS correction)
   # Max age (in training steps) for trajectories used in training
   max_trajectory_age_steps: 1
nemo_rl/algorithms/grpo.py (2)
1238-1240
: Remove unused variable assignment. Avoids F841 and clarifies that start_collection is fire‑and‑forget.
[raise_nitpick_issue]
Apply:
-    collection_task = trajectory_collector.start_collection.remote(dataloader)
+    trajectory_collector.start_collection.remote(dataloader)
1176-1194
: Verify Ray runtime_env usage of py_executable. Ensure your Ray version honors runtime_env['py_executable']; otherwise actors may launch with the driver Python. Affected: nemo_rl/algorithms/grpo.py:1176–1194 (ReplayBuffer). Also review nemo_rl/distributed/worker_groups.py and nemo_rl/distributed/ray_actor_environment_registry.py.
Sandbox could not import ray (ModuleNotFoundError); run locally to confirm Ray version and behavior:
- python -c "import ray; print(ray.__version__)"
- Minimal test: create a tiny actor with .options(runtime_env={'py_executable':'/path/to/python'}) that returns sys.executable and verify it matches the provided py_executable.
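A minimal version of that test, as a sketch (assumes a Ray build that honors runtime_env["py_executable"]; the interpreter path is a placeholder):

import sys

import ray

ray.init()

@ray.remote
class PyExecProbe:
    def which_python(self) -> str:
        return sys.executable

probe = PyExecProbe.options(
    runtime_env={"py_executable": "/path/to/python"}  # placeholder path
).remote()
print(ray.get(probe.which_python.remote()))  # should print the path passed above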
print(f" {self.trajectory_versions=}") | ||
|
||
# For debugging: check for unexpected old trajectories | ||
from collections import Counter | ||
|
||
version_counts = Counter(self.trajectory_versions) | ||
print(f" {version_counts=}") | ||
|
||
# Compute minimum valid version based on age window | ||
# max_age_steps=1 means trajectories from the last 1 step are valid | ||
min_valid_version = max(0, current_weight_version - max_age_steps) | ||
print(f" {min_valid_version=}") | ||
|
||
# Check for unexpected old trajectories | ||
old_trajectories = [ | ||
v for v in self.trajectory_versions if v < min_valid_version | ||
] | ||
if old_trajectories: | ||
raise ValueError( | ||
f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}" | ||
) | ||
|
||
# Filter for valid trajectories without modifying the buffer | ||
valid_indices = [ | ||
i | ||
for i, v in enumerate(self.trajectory_versions) | ||
if min_valid_version <= v <= current_weight_version | ||
] | ||
print( | ||
f" valid_indices: {len(valid_indices)}/{total_trajectories} trajectories within age window" | ||
) | ||
if not valid_indices: | ||
print("No trajectories available for sampling.") | ||
return None | ||
|
||
# Enforce exact number of groups if available; otherwise, signal to wait | ||
if len(valid_indices) < num_prompt_groups: | ||
print( | ||
f"Insufficient valid groups: have {len(valid_indices)}, need {num_prompt_groups}. Waiting for buffer to fill." | ||
) | ||
return None | ||
|
||
# Only select trajectories intended for the current training step | ||
# This ensures no trajectory loses its "last chance" to be used for its intended step | ||
intended_indices = [ | ||
i | ||
for i in valid_indices | ||
if self.target_weight_versions[i] == current_weight_version | ||
] | ||
|
||
print( | ||
f" 🎯 Found {len(intended_indices)} trajectories intended for current step {current_weight_version}" | ||
) | ||
|
||
# Stall training if we don't have enough trajectories intended for this step | ||
if len(intended_indices) < num_prompt_groups: | ||
print( | ||
f" ⏸️ STALLING: Need {num_prompt_groups} trajectories for step {current_weight_version}, but only {len(intended_indices)} are ready" | ||
) | ||
print( | ||
f" ⏸️ Training will wait for remaining {num_prompt_groups - len(intended_indices)} trajectories to be generated" | ||
) | ||
return None | ||
|
||
# Select exactly the trajectories intended for this step (FIFO within same target) | ||
selected: list[int] = intended_indices[:num_prompt_groups] | ||
print( | ||
f" ✅ Selected {len(selected)} trajectories all intended for step {current_weight_version}" | ||
) | ||
|
||
from collections import Counter | ||
|
||
sampled_weights = [self.trajectory_versions[i] for i in selected] | ||
avg_trajectory_age = current_weight_version - sum(sampled_weights) / len( | ||
sampled_weights | ||
) | ||
print( | ||
f"✅ Selected counts by generation weight-version: {Counter(sampled_weights)}" | ||
) | ||
print(f"📊 Average trajectory age: {avg_trajectory_age:.2f} steps") | ||
print( | ||
f"🎯 All selected trajectories target step {current_weight_version} (100% target match)" | ||
) | ||
|
||
sampled_items = [self.trajectories[i] for i in selected] | ||
|
||
# Remove selected items in reverse order to maintain correct indices | ||
for idx in sorted(selected, reverse=True): | ||
self.trajectory_versions.pop(idx) | ||
self.target_weight_versions.pop(idx) | ||
self.trajectories.pop(idx) | ||
print( | ||
f"🗑️ Consumed and removed {len(selected)} groups from buffer, old buffer size: {total_trajectories}, new buffer size: {len(self.trajectories)}, new target weight versions {self.target_weight_versions}" | ||
) | ||
|
||
return { | ||
"trajectories": sampled_items, | ||
"avg_trajectory_age": avg_trajectory_age, | ||
} | ||
|
Sampling: drop stale data instead of raising; rely on intended_indices only
Raising on stale items crashes training; also the precheck on valid_indices is redundant given intended_indices gating.
Apply:
- old_trajectories = [
- v for v in self.trajectory_versions if v < min_valid_version
- ]
- if old_trajectories:
- raise ValueError(
- f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}"
- )
+ # Drop stale entries quietly to avoid stalls/crashes
+ stale_indices = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid_version]
+ for idx in sorted(stale_indices, reverse=True):
+ self.trajectories.pop(idx)
+ self.trajectory_versions.pop(idx)
+ tw = self.target_weight_versions.pop(idx)
+ if hasattr(self, "_target_counts"):
+ self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+ if self._target_counts[tw] <= 0:
+ del self._target_counts[tw]
@@
- if len(valid_indices) < num_prompt_groups:
- print(
- f"Insufficient valid groups: have {len(valid_indices)}, need {num_prompt_groups}. Waiting for buffer to fill."
- )
- return None
+ # Do not early-return here; we stall based on intended_indices below
@@
intended_indices = [
@@
if len(intended_indices) < num_prompt_groups:
@@
for idx in sorted(selected, reverse=True):
- self.trajectory_versions.pop(idx)
- self.target_weight_versions.pop(idx)
- self.trajectories.pop(idx)
+ tw = self.target_weight_versions.pop(idx)
+ self.trajectory_versions.pop(idx)
+ self.trajectories.pop(idx)
+ if hasattr(self, "_target_counts"):
+ self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+ if self._target_counts[tw] <= 0:
+ del self._target_counts[tw]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def sample(
    self,
    num_prompt_groups: int,
    current_weight_version: int,
    max_age_steps: int,
) -> Optional[dict[str, Any]]:
    """Sample per-prompt trajectory groups intended for the current training step.

    Only returns trajectories with target_weight_version == current_weight_version.
    If insufficient trajectories are available, returns None to stall training
    until the remaining trajectories are generated. This ensures no trajectory
    loses its last chance to be used for its intended training step.

    Returns:
        Dictionary with 'trajectories' and 'avg_trajectory_age' keys, or None if insufficient data
    """
    with self._lock:
        if not self.trajectories:
            return None

        total_trajectories = len(self.trajectories)
        print("🔍 ReplayBuffer sampling debug:")
        print(f" {current_weight_version=}, {max_age_steps=}")
        print(f" {self.trajectory_versions=}")

        # For debugging: check for unexpected old trajectories
        from collections import Counter

        version_counts = Counter(self.trajectory_versions)
        print(f" {version_counts=}")

        # Compute minimum valid version based on age window
        # max_age_steps=1 means trajectories from the last 1 step are valid
        min_valid_version = max(0, current_weight_version - max_age_steps)
        print(f" {min_valid_version=}")

        # Drop stale entries quietly to avoid stalls/crashes
        stale_indices = [
            i for i, v in enumerate(self.trajectory_versions) if v < min_valid_version
        ]
        for idx in sorted(stale_indices, reverse=True):
            self.trajectories.pop(idx)
            self.trajectory_versions.pop(idx)
            tw = self.target_weight_versions.pop(idx)
            if hasattr(self, "_target_counts"):
                self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
                if self._target_counts[tw] <= 0:
                    del self._target_counts[tw]

        # Filter for valid trajectories without modifying the buffer
        valid_indices = [
            i
            for i, v in enumerate(self.trajectory_versions)
            if min_valid_version <= v <= current_weight_version
        ]
        print(
            f" valid_indices: {len(valid_indices)}/{total_trajectories} trajectories within age window"
        )
        if not valid_indices:
            print("No trajectories available for sampling.")
            return None

        # Do not early-return here; we stall based on intended_indices below

        # Only select trajectories intended for the current training step
        # This ensures no trajectory loses its "last chance" to be used for its intended step
        intended_indices = [
            i
            for i in valid_indices
            if self.target_weight_versions[i] == current_weight_version
        ]

        print(
            f" 🎯 Found {len(intended_indices)} trajectories intended for current step {current_weight_version}"
        )

        # Stall training if we don't have enough trajectories intended for this step
        if len(intended_indices) < num_prompt_groups:
            print(
                f" ⏸️ STALLING: Need {num_prompt_groups} trajectories for step {current_weight_version}, but only {len(intended_indices)} are ready"
            )
            print(
                f" ⏸️ Training will wait for remaining {num_prompt_groups - len(intended_indices)} trajectories to be generated"
            )
            return None

        # Select exactly the trajectories intended for this step (FIFO within same target)
        selected: list[int] = intended_indices[:num_prompt_groups]
        print(
            f" ✅ Selected {len(selected)} trajectories all intended for step {current_weight_version}"
        )

        from collections import Counter

        sampled_weights = [self.trajectory_versions[i] for i in selected]
        avg_trajectory_age = current_weight_version - sum(sampled_weights) / len(
            sampled_weights
        )
        print(
            f"✅ Selected counts by generation weight-version: {Counter(sampled_weights)}"
        )
        print(f"📊 Average trajectory age: {avg_trajectory_age:.2f} steps")
        print(
            f"🎯 All selected trajectories target step {current_weight_version} (100% target match)"
        )

        sampled_items = [self.trajectories[i] for i in selected]

        # Remove selected items in reverse order to maintain correct indices
        for idx in sorted(selected, reverse=True):
            tw = self.target_weight_versions.pop(idx)
            self.trajectory_versions.pop(idx)
            self.trajectories.pop(idx)
            if hasattr(self, "_target_counts"):
                self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
                if self._target_counts[tw] <= 0:
                    del self._target_counts[tw]
        print(
            f"🗑️ Consumed and removed {len(selected)} groups from buffer, old buffer size: {total_trajectories}, new buffer size: {len(self.trajectories)}, new target weight versions {self.target_weight_versions}"
        )

        return {
            "trajectories": sampled_items,
            "avg_trajectory_age": avg_trajectory_age,
        }
🧰 Tools
🪛 Ruff (0.12.2)
141-143: Avoid specifying long messages outside the exception class (TRY003)
What does this PR do ?
This PR adds support for training GRPO in an async setup. In an async setup, generation and training workers are non-colocated (distinct physical resources), and generation workers can generate training data for future training steps. In the current implementation, every generation worker deterministically generates data for a future training step (the generation worker knows the number of previously generated samples and can compute the targeted training step N steps ahead).
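As a rough illustration, the targeting rule visible in the reviewed code boils down to the sketch below (the standalone function name is hypothetical; the real logic lives in AsyncTrajectoryCollector):

def calculate_target_weights(current_weight_version: int,
                             max_trajectory_age_steps: int) -> list[int]:
    # A collector generating with weights of version `current_weight_version`
    # may produce data for the next `max_trajectory_age_steps` training steps.
    return [current_weight_version + i for i in range(max_trajectory_age_steps)]

print(calculate_target_weights(10, 1))  # [10]
print(calculate_target_weights(10, 4))  # [10, 11, 12, 13]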
This PR lets the user set max_trajectory_age_steps, which controls how stale the data consumed by the trainer is allowed to be. Async RL is only stable when importance sampling correction is enabled (proof for clipping TBD). If importance sampling is disabled, training collapses after a few hundred iterations.
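For context, a minimal sketch of what a token-level importance sampling correction looks like (the exact form used by the loss function may differ; the clipping threshold is a placeholder):

import torch

def importance_weights(trainer_logprobs: torch.Tensor,
                       generation_logprobs: torch.Tensor,
                       max_ratio: float = 2.0) -> torch.Tensor:
    # Ratio between the current policy and the (possibly stale) policy that
    # generated the trajectory, truncated so old samples cannot dominate.
    ratio = torch.exp(trainer_logprobs - generation_logprobs)
    return torch.clamp(ratio, max=max_ratio)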

We have observed convergence results similar to AReal and Prime-RL, with async convergence matching sync convergence for data up to 8-16 steps old (more plots TBD).
Note that in this PR we wait for all ongoing generations to complete when a refit (weight update) request comes in. This is not optimal for performance and causes large inefficiencies in the training pipeline. Support for non-blocking refit will come in a separate PR.
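The blocking refit described above sequences roughly like the sketch below (the collector method names come from this PR; the actual weight-transfer call is left as a caller-supplied placeholder):

import ray

def blocking_refit(trajectory_collector, update_weights_fn, new_weight_version: int) -> None:
    # 1) Drain: wait for in-flight generations to finish before touching weights.
    ray.get(trajectory_collector.prepare_for_refit.remote())
    # 2) Transfer the new weights to the generation workers (placeholder).
    update_weights_fn()
    # 3) Tell the collector which weight version it is now generating with.
    ray.get(trajectory_collector.set_weight_version.remote(new_weight_version))
    # 4) Resume generation for future training steps.
    ray.get(trajectory_collector.resume_after_refit.remote())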
Performance data (will be improved after ^ is addressed):
TBD
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit