
Conversation

@parthchadha (Contributor) commented Sep 8, 2025:

What does this PR do?

This PR adds support for training GRPO in an async setup. In an async setup, generation and training workers are non-colocated (they run on distinct physical resources), and generation workers can produce training data for future training steps. In the current implementation, every generation worker deterministically targets a future training step: it knows how many samples have already been generated and can compute the target training step N steps ahead.

This PR also lets the user set max_trajectory_age_steps, which controls how stale (in training steps) the data consumed by the trainer is allowed to be.

Async RL is only stable when importance sampling correction is enabled (a proof for the clipped case is TBD). If importance sampling correction is disabled, training collapses after a few hundred iterations.
(Screenshot: training collapse when importance sampling correction is disabled.)
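For intuition only (a sketch of the standard token-level correction, not a restatement of the exact loss in this PR; the precise form, including how it interacts with clipping, is whatever use_importance_sampling_correction enables in the loss config): stale trajectories were sampled under the generation-time policy, while gradients are taken under the current policy, so each token's loss term is reweighted by an importance ratio computed from the stored generation_logprobs:

w_t = \frac{\pi_{\theta}(a_t \mid s_{<t})}{\pi_{\theta_{\mathrm{gen}}}(a_t \mid s_{<t})} = \exp\bigl(\log \pi_{\theta}(a_t \mid s_{<t}) - \mathrm{generation\_logprob}_t\bigr)

Without this reweighting, the gradient is estimated under the stale behavior distribution rather than the current policy, which is consistent with the collapse shown above.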

We have observed convergence results similar to AReal and Prime-RL, with async convergence matching sync convergence for data up to 8-16 steps old (more plots TBD).

Note that in this PR we wait for all in-flight generations to complete when a refit (weight update) request comes in. This is not optimal for performance and introduces large bubbles in the training pipeline. Support for non-blocking refit will come in a separate PR.
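For reference, a minimal sketch of the blocking refit described above, using the collector methods this PR adds (prepare_for_refit, set_weight_version, resume_after_refit); update_generation_weights is a hypothetical placeholder for the repo's actual weight-transfer path, and error handling is omitted:

import ray

def blocking_refit(trajectory_collector, new_weight_version: int, update_generation_weights) -> None:
    # Wait for all in-flight rollouts to finish (the performance bubble noted above).
    ray.get(trajectory_collector.prepare_for_refit.remote())

    # Placeholder: push the freshly trained weights to the non-colocated generation workers.
    update_generation_weights()

    # Tell the collector which weight version it is generating with from now on,
    # so new trajectories are tagged and targeted against the right training steps.
    ray.get(trajectory_collector.set_weight_version.remote(new_weight_version))

    # Resume background trajectory collection.
    ray.get(trajectory_collector.resume_after_refit.remote())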

Performance data (will be improved after ^ is addressed):
TBD

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
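A minimal usage sketch (hypothetical, not the exact entry-point code): the async path is toggled by grpo.async_grpo.enabled and bounded by max_trajectory_age_steps; train_async/train_sync below are stand-ins for async_grpo_train/grpo_train, whose real signatures take many more arguments.

def train_async(max_trajectory_age_steps: int) -> None:
    # Stand-in for async_grpo_train
    print(f"async GRPO, trajectories may be up to {max_trajectory_age_steps} step(s) old")

def train_sync() -> None:
    # Stand-in for grpo_train
    print("synchronous GRPO")

def launch(master_config: dict) -> None:
    async_cfg = master_config["grpo"].get("async_grpo", {})
    if async_cfg.get("enabled", False):
        train_async(async_cfg.get("max_trajectory_age_steps", 1))
    else:
        train_sync()

launch({"grpo": {"async_grpo": {"enabled": True, "max_trajectory_age_steps": 1}}})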

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Optional asynchronous GRPO training with background trajectory collection; falls back to synchronous when disabled.
  • Improvements
    • Better GPU memory cleanup after validation to reduce memory pressure.
  • Configuration
    • New grpo.async_grpo block: enabled (bool) and max_trajectory_age_steps (int) to control async behavior.
  • Documentation
    • Added guidance on importance sampling correction settings in relation to async GRPO.

parthchadha and others added 23 commits July 30, 2025 23:30
@terrykong (Contributor) left a comment:

awesome work @parthchadha. Do you think you could maybe add a sequence diagram like this https://docs.mermaidchart.com/mermaid-oss/syntax/sequenceDiagram.html#actors somewhere in the docs to help understand the flow of data/stalling? Still reviewing, but this is an initial pass.

also, does this close #600 ? if so, could you add it to the PR description so it'll auto close when completed?

print(f" - train_global_batch_size: {train_gbs}")
print(f" - min_trajectories_needed: {min_trajectories_needed} (async mode)")

_replay_py_exec = get_actor_python_env(
Contributor:

did using RayWorkerGroup not work for this use-case? too much boilerplate?

Contributor Author:

I think it didn't work, don't recall the error but will try again and update the code if it works.

Contributor:

i wonder if it's due to the same issue @RayenTian faced in the reward model pr:


https://github.com/NVIDIA-NeMo/RL/pull/1026/files

# Clean up
print("🛑 Stopping trajectory collection...")
try:
    ray.get(trajectory_collector.stop.remote())
Contributor:

if you call stop and it set()s all the threading events, but kill the actor right after, it seems like the amount of time elapsed is negligible, so the ray.kill following this may still be "disruptive". so is this "stop" even needed?

i can see this being useful for unit tests though

Contributor Author:

Agreed, we can delete stop and just kill, since we aren't storing the trajectories in serialized form for further use right now, so it's fine to just kill the trajectory collector.

self.max_size = max_size
self.trajectories = []
self.trajectory_versions = [] # weight-version used for generation
self.target_weight_versions = [] # weight-version this trajectory is targeted for
Contributor:

could you add a comment somewhere to help disambiguate these two versions?

@terrykong (Contributor):

took a stab:

sequenceDiagram
    autonumber
    participant Trainer
    participant AsyncCollector as AsyncTrajectoryCollector
    participant Dataloader
    participant Policy as PolicyGeneration
    participant Envs as Environments
    participant Buffer as ReplayBuffer

    Trainer->>AsyncCollector: set_weight_version(t)
    Note right of AsyncCollector: g = current generation_weight_version = t

    par Continuous collection
        AsyncCollector->>Dataloader: next batch
        AsyncCollector->>AsyncCollector: target_weights = _calculate_target_weights(g)
        AsyncCollector->>Buffer: get_last_target_weight_already_generated()
        Buffer-->>AsyncCollector: last_tgt
        AsyncCollector->>AsyncCollector: tgt = _get_next_target_for_generation(g)\n(reserve if tgt > last_tgt and not in-flight)

        alt tgt found
            loop for each prompt in batch
                AsyncCollector->>Policy: run_async_multi_turn_rollout(repeated prompt)
                Policy->>Envs: step through rollout
                Envs-->>Policy: transitions
                Policy-->>AsyncCollector: trajectory_group
                AsyncCollector->>Buffer: push_with_wait_signal(group, weight_version=g, target_weight_version=tgt)
                Buffer->>Buffer: trajectories += group
                Buffer->>Buffer: trajectory_versions += g
                Buffer->>Buffer: target_weight_versions += tgt
                Buffer->>Buffer: last_target_weight_already_generated = max(last, tgt)
                Buffer-->>AsyncCollector: "success" | "full"
            end
        else no tgt available
            AsyncCollector-->>AsyncCollector: pause/wait (all targets covered or reserved)
        end
    and Training
        loop until enough groups for step t
            Trainer->>Buffer: sample(num_groups, current_weight_version=t, max_age)
            Buffer->>Buffer: min_valid = max(0, t - max_age)
            Buffer->>Buffer: valid_indices: min_valid <= generation g <= t
            Buffer->>Buffer: intended_indices: target_weight_versions == t
            alt enough intended groups
                Buffer-->>Trainer: trajectories + avg_trajectory_age
                Buffer->>Buffer: remove selected entries
            else insufficient
                Buffer-->>Trainer: None (stall until more for target t)
            end
        end
    end

    Note over AsyncCollector: _calculate_target_weights(g)\n- If g == initial_weight_version:\n  [initial, ..., initial+max_age]\n  (include current step at cold start)\n- Else: [g+1, ..., g+max_age]\n  (only future targets once warm)
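A runnable sketch of the target-selection rule in the note above (the real method lives on AsyncTrajectoryCollector in nemo_rl/algorithms/async_utils.py and may differ in details):

def calculate_target_weights(
    generation_weight_version: int,
    initial_weight_version: int,
    max_trajectory_age_steps: int,
) -> list[int]:
    if generation_weight_version == initial_weight_version:
        # Cold start: the very first trajectories may also target the current step.
        return list(range(initial_weight_version, initial_weight_version + max_trajectory_age_steps + 1))
    # Warm: only future training steps within the staleness window are valid targets.
    return list(range(generation_weight_version + 1, generation_weight_version + max_trajectory_age_steps + 1))

# A generation server at weight version 10 with max_trajectory_age_steps=4 can feed steps 11-14.
assert calculate_target_weights(10, 0, 4) == [11, 12, 13, 14]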


coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds an optional asynchronous GRPO training path. New async utilities (ReplayBuffer, AsyncTrajectoryCollector) with Ray actors manage background trajectory collection. GRPO gains async_grpo_train and memory cleanup. Entry script toggles async vs sync based on config. Configs gain grpo.async_grpo fields. Ray actor environment registry updated for new actors.

Changes

  • Entry toggle and configs (examples/run_grpo_math.py, examples/configs/grpo_math_1B.yaml, examples/configs/grpo_math_8B.yaml): Adds config block grpo.async_grpo with enabled and max_trajectory_age_steps. Entry script conditionally calls async_grpo_train (keyword args) or grpo_train (sync). Passes grpo_save_state and max_trajectory_age_steps to the async path.
  • Async utilities (nemo_rl/algorithms/async_utils.py): Introduces ReplayBuffer (Ray actor) with push/sample/query APIs and AsyncTrajectoryCollector coordinating multi-turn rollouts, buffering, pause/resume, refit, and state checkpointing using threads, locks, events, and semaphores.
  • GRPO algorithm (nemo_rl/algorithms/grpo.py): Adds async_grpo_train implementing asynchronous training with Ray actors, weight-version coordination, validation pause, checkpointing, and GC/CUDA cache cleanup. Adds _should_use_async_rollouts helper; preserves the sync flow.
  • Ray env mapping (nemo_rl/distributed/ray_actor_environment_registry.py): Maps AsyncTrajectoryCollector and ReplayBuffer to the vLLM Python environment in ACTOR_ENVIRONMENT_REGISTRY.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Runner as run_grpo_math.py
  participant GRPO as nemo_rl.algorithms.grpo
  Note over Runner,GRPO: Training entry

  User->>Runner: Launch with config
  Runner->>Runner: Read grpo.async_grpo.enabled
  alt async enabled
    Runner->>GRPO: async_grpo_train(..., grpo_save_state, max_trajectory_age_steps)
  else sync
    Runner->>GRPO: grpo_train(...)
  end
sequenceDiagram
  autonumber
  participant Trainer as async_grpo_train
  participant RB as ReplayBuffer (Ray)
  participant Collector as AsyncTrajectoryCollector (Ray)
  participant Policy as Policy/Generation
  participant Env as Envs
  participant Val as Validator
  Note over Trainer: Initialization
  Trainer->>RB: start actor
  Trainer->>Collector: start actor (policy, tokenizer, envs)
  Trainer->>Collector: set_weight_version(v0)
  loop training steps
    Note over Collector,Env: Background rollout
    Collector->>Env: run_async_multi_turn_rollout
    Env-->>Collector: trajectories
    Collector->>RB: push_with_wait_signal(groups, gen_w, target_w)
    Note over Trainer,RB: Sampling
    Trainer->>RB: sample(num_groups, current_w, max_age)
    RB-->>Trainer: groups or None
    alt got sample
      Trainer->>Policy: logprob/inference
      Policy-->>Trainer: scores
      Trainer->>Trainer: compute baselines/advantages, optimize
      opt optional refit
        Trainer->>Collector: pause/prepare_for_refit
        Trainer->>Policy: update weights (version++)
        Trainer->>Collector: set_weight_version/new, resume_after_refit
      end
    else stall
      Trainer->>Trainer: wait/backoff
    end
    opt periodic validation
      Trainer->>Collector: pause
      Trainer->>Val: run validation
      Val-->>Trainer: metrics
      Trainer->>Trainer: gc.collect(), torch.cuda.empty_cache()
      Trainer->>Collector: resume
    end
    opt checkpoint
      Trainer->>Trainer: save state
    end
  end
  Note over Trainer,RB: Cleanup on exit
  Trainer->>Collector: stop/cleanup
  Trainer->>RB: clear/stop

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Poem

A rabbit taps its async feet,
Buffers brim with fresh-baked treats.
Rollouts hop, then trainers chew,
Versions tick—a carrot new.
When pause for checks, we sniff the breeze,
Then race ahead with lightning ease.
GRPO hums through data trees.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 75.00%, which is insufficient; the required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title "feat: add async RL support" is concise and accurately reflects the primary purpose of the changeset (adding asynchronous RL support: async GRPO training path, async utilities, replay buffer, and config flags), so it is relevant and clear for a reviewer scanning PR history.


@coderabbitai bot left a comment:

Actionable comments posted: 5

♻️ Duplicate comments (1)
nemo_rl/algorithms/grpo.py (1)

1127-1129: Fix: incorrect check for importance sampling; loss_fn is not subscriptable

This asserts into the callable loss object; use the loss config instead.

Apply:

-    assert loss_fn["use_importance_sampling_correction"] is True, (
+    assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, (
         "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
     )
🧹 Nitpick comments (22)
examples/configs/async_grpo_math_1B.yaml (4)

7-7: Fix stale comment for trajectory age.

Comment says “last 4 training steps” but value is 1. Align the comment with the value.

-  max_trajectory_age_steps: 1  # Allow trajectories from the last 4 training steps
+  max_trajectory_age_steps: 1  # Allow trajectories from the last 1 training step

34-41: Normalize boolean casing for consistency.

Mix of True/False and true/false. Prefer one style (repo convention) for readability.

-    cpu_offload: False
+    cpu_offload: false
-    activation_checkpointing: false
+    activation_checkpointing: false
-    enabled: True
+    enabled: true
-    enabled: False
+    enabled: false
-      enforce_eager: False
+      enforce_eager: false

Also applies to: 43-47, 58-66, 70-70, 88-95, 96-109


58-64: Optional: double-check vLLM lengths.

max_new_tokens equals max_total_sequence_length and vLLM max_model_len equals the same. Depending on prompt length, this can cap generations early or waste headroom. Consider setting max_new_tokens = max_total_sequence_length - max_input_seq_length (or a fixed cap) if that matches your runtime assumptions.

Also applies to: 49-56


82-82: Clean up trailing spaces and add newline at EOF.

Yamllint flagged trailing spaces and missing newline.

-# Environment configuration  
+# Environment configuration
@@
-    flush_interval: 10 
-\ No newline at end of file
+    flush_interval: 10
+

Also applies to: 108-109

examples/configs/async_grpo_math_8B.yaml (3)

21-21: Clarify generation batch-size note.

Comment says “Only used when generating using HF backend” while backend is vLLM. Either drop the note or indicate it’s ignored with vLLM to avoid confusion.

-  generation_batch_size: 32 # Only used when generating using HF backend
+  generation_batch_size: 32 # Ignored when using vLLM backend

Also applies to: 60-76


27-35: Normalize boolean casing for consistency.

Unify True/False to repo-preferred casing.

-    cpu_offload: False
+    cpu_offload: false
-  dynamic_batching:
-    enabled: False
+  dynamic_batching:
+    enabled: false
-      enforce_eager: False
+      enforce_eager: false

Also applies to: 36-38, 69-75


59-59: Remove trailing spaces and add newline at EOF.

Yamllint warnings.

-  
+
@@
-    flush_interval: 10 
-\ No newline at end of file
+    flush_interval: 10
+

Also applies to: 97-98

nemo_rl/distributed/ray_actor_environment_registry.py (1)

39-42: ReplayBuffer likely doesn’t need vLLM environment.

Unless it imports vLLM symbols directly, mapping ReplayBuffer to PY_EXECUTABLES.VLLM increases dependency surface and startup time. Prefer PY_EXECUTABLES.BASE or SYSTEM. Keep AsyncTrajectoryCollector under VLLM if it handles vLLM exceptions.

-    # ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker
-    "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.VLLM,
+    # ReplayBuffer is transport-only; avoid vLLM dependency bloat
+    "nemo_rl.algorithms.async_utils.ReplayBuffer": PY_EXECUTABLES.BASE,

If exceptions from vLLM are serialized through Ray and require import on the receiver, keep it as-is; otherwise prefer BASE.

examples/run_grpo_math.py (2)

281-295: Pass kwargs to grpo_train for signature stability.

Synchronous path uses positional args; safer to pass by name to avoid breakage if the signature evolves.

-        grpo_train(
-            policy,
-            policy_generation,
-            dataloader,
-            val_dataloader,
-            tokenizer,
-            loss_fn,
-            task_to_env,
-            val_task_to_env,
-            logger,
-            checkpointer,
-            grpo_state,
-            master_config,
-        )
+        grpo_train(
+            policy=policy,
+            policy_generation=policy_generation,
+            dataloader=dataloader,
+            val_dataloader=val_dataloader,
+            tokenizer=tokenizer,
+            loss_fn=loss_fn,
+            task_to_env=task_to_env,
+            val_task_to_env=val_task_to_env,
+            logger=logger,
+            checkpointer=checkpointer,
+            grpo_save_state=grpo_state,
+            master_config=master_config,
+        )

225-233: Guard for generation config presence is good; consider asserting IS correction for async.

async_grpo_train asserts importance-sampling correction; you could proactively warn in the runner if async_grpo.enabled and loss config disables it.

nemo_rl/algorithms/grpo.py (8)

1160-1169: Unused variable and redundant print

train_gbs is never used; drop it. Also the duplicate “num_generations_per_prompt/samples_per_prompt_group” lines say the same thing.

-    train_gbs = master_config["policy"]["train_global_batch_size"]
@@
-    print(f"   - train_global_batch_size: {train_gbs}")

1234-1234: Remove unused assignment

Variable is never used.

-    collection_task = trajectory_collector.start_collection.remote(dataloader)
+    trajectory_collector.start_collection.remote(dataloader)

1307-1316: wait_iterations never increments

If you keep the debug loop, increment or drop the counter.

-    wait_iterations = 0
+    wait_iterations = 0
@@
-        # wait_iterations += 1
+        wait_iterations += 1

1432-1436: Assertion message length

Minor: prefer a concise message or a custom exception to satisfy TRY003.

-    raise AssertionError(
-        f"Configuration error: (num_prompts_per_step * num_generations_per_prompt) = {expected_batch_size} must be divisible by data_parallel size {dp_size}."
-    )
+    raise AssertionError(
+        f"Train batch ({expected_batch_size}) must be divisible by DP size ({dp_size})."
+    )

1495-1511: Rename unused loop var

j unused.

-                        for j, message in enumerate(message_log):
+                        for _j, message in enumerate(message_log):

1606-1611: Redundant import and GPU mem cleanup

gc already imported at file top. Keep cleanup, drop local import.

-                    import gc
-
                     gc.collect()
                     torch.cuda.empty_cache()

1635-1640: warnings.warn without stacklevel

Add stacklevel=2 for actionable locations.

-                            warnings.warn(
+                            warnings.warn(
                                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
                                 "Saving most recent k checkpoints instead."
-                            )
+                            , stacklevel=2)

1441-1449: Prompt-only advantages in async path vs. full-input in sync path

Async computes baselines from prompt-only tokens (good). Sync path still uses all tokens; this may cause a behavior divergence.

Unify both paths to prompt-only baseline (or document why they differ).

Also applies to: 1501-1507

nemo_rl/algorithms/async_utils.py (4)

424-429: Broad exception handling

Catching bare Exception repeatedly obscures failures. Consider narrowing (ValueError/RuntimeError) or re-raising after logging.

If you must keep broad catches in actors, at least log the full stack and include target/prompt ids for triage (you already print tracebacks).

Also applies to: 489-494, 631-636, 636-641, 657-661


258-266: Unused _pg_lock

Not used; remove to reduce noise.

-        self._pg_lock: _threading.Lock = _threading.Lock()

95-99: get_existing_target_weights currently unused

Keep if you plan to expose telemetry; otherwise remove.


134-146: Raising on “old trajectories” may crash long runs

Turning this into a warning + purge is gentler, now that per-target quotas prevent accumulation.

-            if old_trajectories:
-                raise ValueError(
-                    f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}"
-                )
+            if old_trajectories:
+                print(f"⚠️ Dropping {len(old_trajectories)} old trajectories (< {min_valid_version})")
+                # optional: actually purge them here
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0588dc and 6868407.

📒 Files selected for processing (6)
  • examples/configs/async_grpo_math_1B.yaml (1 hunks)
  • examples/configs/async_grpo_math_8B.yaml (1 hunks)
  • examples/run_grpo_math.py (1 hunks)
  • nemo_rl/algorithms/async_utils.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (4 hunks)
  • nemo_rl/distributed/ray_actor_environment_registry.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
  • PY_EXECUTABLES (42-58)
nemo_rl/algorithms/async_utils.py (6)
nemo_rl/algorithms/grpo.py (1)
  • MasterConfig (121-129)
nemo_rl/data/interfaces.py (1)
  • DatumSpec (32-40)
nemo_rl/distributed/batched_data_dict.py (2)
  • BatchedDataDict (75-839)
  • repeat_interleave (703-724)
nemo_rl/environments/interfaces.py (1)
  • EnvironmentInterface (52-88)
nemo_rl/experience/rollouts.py (1)
  • run_async_multi_turn_rollout (751-895)
nemo_rl/models/generation/interfaces.py (1)
  • GenerationInterface (208-242)
examples/run_grpo_math.py (1)
nemo_rl/algorithms/grpo.py (2)
  • async_grpo_train (1090-1738)
  • grpo_train (513-972)
nemo_rl/algorithms/grpo.py (8)
nemo_rl/utils/timer.py (4)
  • time (110-123)
  • Timer (22-248)
  • get_timing_metrics (196-233)
  • reset (235-248)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
  • get_actor_python_env (47-62)
nemo_rl/utils/venvs.py (1)
  • create_local_venv_on_each_node (152-189)
nemo_rl/utils/logger.py (2)
  • Logger (710-933)
  • log_batched_dict_as_jsonl (804-828)
nemo_rl/algorithms/async_utils.py (12)
  • AsyncTrajectoryCollector (239-660)
  • ReplayBuffer (36-235)
  • start_collection (362-373)
  • set_weight_version (327-336)
  • pause (498-501)
  • resume (503-506)
  • size (225-228)
  • sample (100-223)
  • get_debug_info (82-89)
  • prepare_for_refit (508-526)
  • resume_after_refit (528-531)
  • get_dataloader_state (555-559)
nemo_rl/distributed/batched_data_dict.py (4)
  • size (793-802)
  • BatchedDataDict (75-839)
  • from_batches (102-151)
  • to (804-811)
nemo_rl/data/llm_message_utils.py (1)
  • batched_message_log_to_flat_message (233-390)
nemo_rl/algorithms/utils.py (1)
  • calculate_baseline_and_std_per_prompt (47-116)
🪛 YAMLlint (1.37.1)
examples/configs/async_grpo_math_1B.yaml

[error] 82-82: trailing spaces

(trailing-spaces)


[error] 108-108: trailing spaces

(trailing-spaces)


[error] 109-109: no new line character at the end of file

(new-line-at-end-of-file)


[error] 109-109: trailing spaces

(trailing-spaces)

examples/configs/async_grpo_math_8B.yaml

[error] 59-59: trailing spaces

(trailing-spaces)


[error] 97-97: trailing spaces

(trailing-spaces)


[error] 98-98: no new line character at the end of file

(new-line-at-end-of-file)


[error] 98-98: trailing spaces

(trailing-spaces)

🪛 Ruff (0.12.2)
nemo_rl/algorithms/async_utils.py

143-145: Avoid specifying long messages outside the exception class

(TRY003)


358-358: Consider moving this statement to an else block

(TRY300)


359-359: Do not catch blind exception: Exception

(BLE001)


424-424: Do not catch blind exception: Exception

(BLE001)


489-489: Do not catch blind exception: Exception

(BLE001)


631-631: Do not catch blind exception: Exception

(BLE001)


636-636: Do not catch blind exception: Exception

(BLE001)


657-657: Do not catch blind exception: Exception

(BLE001)

nemo_rl/algorithms/grpo.py

1234-1234: Local variable collection_task is assigned to but never used

Remove assignment to unused variable collection_task

(F841)


1252-1252: Do not catch blind exception: Exception

(BLE001)


1263-1263: Do not catch blind exception: Exception

(BLE001)


1291-1291: Do not catch blind exception: Exception

(BLE001)


1434-1436: Avoid specifying long messages outside the exception class

(TRY003)


1495-1495: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


1635-1635: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


1730-1730: Do not catch blind exception: Exception

(BLE001)


1735-1735: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (13)
examples/configs/async_grpo_math_1B.yaml (1)

2-2: Defaults resolution path verified: load_config_with_inheritance wraps a string defaults value into a list before merging parent configs, so defaults: "grpo_math_1B.yaml" is correctly applied.

examples/configs/async_grpo_math_8B.yaml (2)

2-2: Verify defaults resolution path (same as 1B config).

Ensure defaults: "grpo_math_1B.yaml" is actually merged by your loader.

Use the same script provided on the 1B config comment.


18-23: Sanity-check effective batch vs hardware.

train_global_batch_size: 512, train_micro_batch_size: 1, gpus_per_node: 8, num_nodes: 1 implies very high gradient accumulation (roughly 512 / (8 × 1) = 64 micro-batches per optimizer step if fully data-parallel). Confirm divisibility and memory headroom at 4096 sequence length; otherwise consider reducing the global batch or enabling dynamic batching.

Also applies to: 81-84

examples/run_grpo_math.py (1)

255-277: LGTM: clean async toggle and call surface.

Conditional path, clear print, and explicit kwargs into async_grpo_train look good.

nemo_rl/algorithms/grpo.py (5)

1151-1153: Colocated inference unsupported in async path — good to assert

Clear early failure mode. Consider improving the message with a hint to set policy.generation.colocated.enabled=false.


1273-1300: Pause/resume around initial validation — good guardrails

Pausing collection avoids pressure during validation. Nice.


1666-1667: Verify offload_after_refit at end of checkpoint

This method is intended for refit. Using it post-checkpoint may evict state unexpectedly if the next step runs immediately.

Would you confirm it’s safe here for Megatron/vLLM? If not, consider removing it in the checkpoint path.


1083-1086: GPU memory cleanup after validation — looks good

Makes OOMs less likely during long async runs.


1424-1436: No action required: BatchedDataDict.size is a @property, and calls to .size (no parentheses) are correct.

nemo_rl/algorithms/async_utils.py (4)

282-289: Target selection and pause logic — LGTM

Deterministic future targeting with a bounded age window and backpressure via last-target is sound given the per-target quota fix.

Also applies to: 306-325, 338-361, 392-414


438-471: num_prompts = batch.size — property vs method

Same note as main file: ensure .size is a property, or call batch.size().

See script in grpo.py comment.


561-566: Thread set maintenance — good

Cleanup avoids unbounded growth of _inflight_threads.


367-374: Daemon collection thread and clean return log — LGTM

Actor lifecycle is clean; driver kills actors in finally.

Comment on lines +35 to +51
@ray.remote
class ReplayBuffer:
    """Replay buffer storing per-prompt groups.

    A single entry corresponds to 1 prompt repeated by
    grpo.num_generations_per_prompt (required to compute per-prompt advantages).
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.trajectories = []
        # If trajectory_version is 1 and target_weight_version is 4 it means that weight version 1 was used for generating a trajectory and this trajectory will be used for training when weight version is 4.
        self.trajectory_versions = []  # it is the weight-version used for generation of a trajectory
        self.target_weight_versions = []  # it is the weight-version of the trainer where this trajectory will be used.

        self.last_target_weight_already_generated = -1
        self._lock = _threading.Lock()

🛠️ Refactor suggestion

Add per-target quotas and late-advance semantics to ReplayBuffer

Enforce exactly groups_per_target groups per target, and only mark last_target_weight_already_generated when the quota is met. Prevents overfill/underfill and avoids stale leftovers that later trigger the “old trajectory” error.

Apply:

- class ReplayBuffer:
+ class ReplayBuffer:
@@
-    def __init__(self, max_size: int):
+    def __init__(self, max_size: int, groups_per_target: int):
         self.max_size = max_size
+        self.groups_per_target = groups_per_target
         self.trajectories = []
@@
         self.last_target_weight_already_generated = -1
         self._lock = _threading.Lock()
+        self._target_counts: dict[int, int] = {}
@@
-    ) -> str:
+    ) -> str:
@@
-            self.trajectories.append(trajectory)
-            self.trajectory_versions.append(weight_version)
-            self.target_weight_versions.append(target_weight_version)
-            self.last_target_weight_already_generated = max(
-                self.last_target_weight_already_generated, target_weight_version
-            )
+            current = self._target_counts.get(target_weight_version, 0)
+            if current >= self.groups_per_target:
+                return "redundant"
+            self.trajectories.append(trajectory)
+            self.trajectory_versions.append(weight_version)
+            self.target_weight_versions.append(target_weight_version)
+            self._target_counts[target_weight_version] = current + 1
+            # Advance only when quota is satisfied
+            if self._target_counts[target_weight_version] >= self.groups_per_target:
+                self.last_target_weight_already_generated = max(
+                    self.last_target_weight_already_generated, target_weight_version
+                )
@@
-            sampled_items = [self.trajectories[i] for i in selected]
+            sampled_items = [self.trajectories[i] for i in selected]
@@
             for idx in sorted(selected, reverse=True):
-                self.trajectory_versions.pop(idx)
-                self.target_weight_versions.pop(idx)
-                self.trajectories.pop(idx)
+                tw = self.target_weight_versions.pop(idx)
+                self.trajectory_versions.pop(idx)
+                self.trajectories.pop(idx)
+                self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+                if self._target_counts[tw] <= 0:
+                    del self._target_counts[tw]
@@
         with self._lock:
             return len(self.trajectories)

Additionally, return "complete" when quota is met:

-            return "success"
+            return "complete" if self._target_counts[target_weight_version] >= self.groups_per_target else "success"

Also applies to: 82-90, 95-99, 100-223, 225-236

Comment on lines +610 to +623
f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
)

# Release reservation when FIRST prompt group for this target is successfully buffered
if prompt_idx == 0:
with self._generation_check_lock:
if target_weight_version in self._generating_targets:
self._generating_targets.discard(
target_weight_version
)
print(
f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)"
)
break

🛠️ Refactor suggestion

⚠️ Potential issue

Reservation released on first push — can stall training on partial failures

Releasing _generating_targets at prompt_idx == 0 advances scheduling even if not all groups are buffered. If some threads fail, training will wait forever for the missing groups, and the collector won’t regenerate them.

Fix by making ReplayBuffer enforce per-target counts and only “advance” a target after groups_per_target pushes. Have push_with_wait_signal return "complete" when it reaches the quota; release the reservation only on "complete".

-                    if status == "success":
+                    if status in ("success", "complete"):
                         print(
                             f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
                         )
-
-                        # Release reservation when FIRST prompt group for this target is successfully buffered
-                        if prompt_idx == 0:
-                            with self._generation_check_lock:
-                                if target_weight_version in self._generating_targets:
-                                    self._generating_targets.discard(
-                                        target_weight_version
-                                    )
-                                    print(
-                                        f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)"
-                                    )
+                        if status == "complete":
+                            with self._generation_check_lock:
+                                if target_weight_version in self._generating_targets:
+                                    self._generating_targets.discard(target_weight_version)
+                                    print(f"🧹 Released reservation for target weight {target_weight_version} (all prompts buffered)")
                         break

See ReplayBuffer changes below.

Also applies to: 642-649


Comment on lines 1144 to 1147
step = grpo_save_state["step"]
weight_version = step # Tracks refitted weight versions
consumed_samples = grpo_save_state["consumed_samples"]
val_period = master_config["grpo"]["val_period"]

⚠️ Potential issue

Fix: inconsistent training state keys; grpo_save_state['step'] will KeyError

Save-state schema uses current_step/total_steps. Unify reads/writes to avoid crashes and incompatible checkpoints.

Apply:

-    step = grpo_save_state["step"]
+    step = grpo_save_state.get("current_step", 0)
@@
-                    grpo_save_state["step"] = step + 1
+                    grpo_save_state["current_step"] = step + 1

Also applies to: 1623-1629

Comment on lines +1190 to +1202
    # Calculate optimal buffer size based on generation limits to prevent length bias
    # Each weight version generates exactly num_prompts_per_step trajectories
    # With max_age_steps, we keep trajectories from multiple weight versions
    num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"]
    late_arrival_slack = 2
    optimal_buffer_size = (
        num_prompts_per_step * max_trajectory_age_steps * late_arrival_slack
    )

    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
        max_size=optimal_buffer_size
    )


🛠️ Refactor suggestion

ReplayBuffer ctor lacks per-target capacity; risk of over/under-fill and staleness

Without per-target quotas, a partial or duplicated target can permanently stall sampling or create stale leftovers. Pass groups-per-target and enforce it in the actor (see async_utils diff).

Apply here after updating ReplayBuffer (see other file):

-    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
-        max_size=optimal_buffer_size
-    )
+    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+        max_size=optimal_buffer_size,
+        groups_per_target=num_prompts_per_step,
+    )

try:
    # Run rollout for this prompt group
    # Async engine supports concurrent generation; avoid locking
    final_batch, rollout_metrics = run_async_multi_turn_rollout(
Contributor:

wdyt about timing this block and aggregating? i'm wondering if it's possible to measure the "generation bubble". Since with enough generation servers it seems like you could have lots of idle time maybe signaling you have too many generation servers or too low max_trajectory_age_steps
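A small sketch of what that timing could look like (illustrative names only, not code from this PR): wrap each rollout call and accumulate per-target totals, so idle time between rollouts can be surfaced as a "generation bubble" metric.

import time
from collections import defaultdict

rollout_seconds: dict[int, float] = defaultdict(float)

def timed_rollout(run_rollout, target_weight_version: int, *args, **kwargs):
    # Time a single rollout and attribute it to the training step it targets.
    start = time.perf_counter()
    try:
        return run_rollout(*args, **kwargs)
    finally:
        rollout_seconds[target_weight_version] += time.perf_counter() - start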

Comment on lines 4 to 11
# Async-specific settings
async_grpo:
  enabled: true  # Enable async training
  max_trajectory_age_steps: 1  # Allow trajectories from the last 1 training steps

grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
Contributor:

was there an advantage to maintaining another config? wdyt about merging this config with grpo_math_1B.yaml and moving the async args under the algo root (default to false)? then putting these two recipes into examples/configs/recipes/llm/ & tests/test_suites/llm/ so we can track these nightly?

Suggested change

Before:
# Async-specific settings
async_grpo:
  enabled: true  # Enable async training
  max_trajectory_age_steps: 1  # Allow trajectories from the last 1 training steps

grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32

After:
grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
  # Async-specific settings
  async_cfg:
    enabled: true  # Enable async training
    max_trajectory_age_steps: 1  # Allow trajectories from the last 1 training steps

Contributor Author:

Thanks, moved under grpo and removed async config examples.


# Wait for initial buffer fill
print(
f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..."
Contributor:

Suggested change
f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..."
f"⏳ Waiting for replay buffer to have sufficient trajectories ({min_trajectories_needed=})..."

to avoid being conflated with minutes

buffer_size_current = ray.get(replay_buffer.size.remote())

print(
f" Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}"
Contributor:

Suggested change
f" Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}"
f" Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}"

if buffer_size_current >= min_trajectories_needed:
break

# wait_iterations += 1
Contributor:

I'm guessing this comment is intentional. Do you mind leaving a comment if so?

f" Trajectory versions in buffer: {buffer_debug['trajectory_versions']}"
)

time.sleep(0.5)
Contributor:

is it possible the prints above this could get spammy if it is stuck in this loop for a while?

Contributor Author:

It does but in practice I found these to be useful and not too spammy. We can revisit later if we want to remove these.

total_trajectories = len(self.trajectories)
print("🔍 ReplayBuffer sampling debug:")
print(
f" current_weight_version={current_weight_version}, max_age_steps={max_age_steps}"
Contributor:

Suggested change
f" current_weight_version={current_weight_version}, max_age_steps={max_age_steps}"
f" {current_weight_version=}, {max_age_steps=}"

# Compute minimum valid version based on age window
# max_age_steps=1 means trajectories from the last 1 step are valid
min_valid_version = max(0, current_weight_version - max_age_steps)
print(f" min_valid_version={min_valid_version}")
Contributor:

Suggested change
print(f" min_valid_version={min_valid_version}")
print(f" {min_valid_version=}")

save a few characters :)
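For reference, a standalone sketch of the staleness filter quoted above (simplified from ReplayBuffer.sample; the actual actor additionally checks whether enough groups exist for the target step before returning anything):

def select_indices(
    trajectory_versions: list[int],
    target_weight_versions: list[int],
    current_weight_version: int,
    max_age_steps: int,
) -> list[int]:
    # A buffered group is usable for training step t only if its generation version g
    # lies within [t - max_age_steps, t] and it was targeted at t.
    min_valid_version = max(0, current_weight_version - max_age_steps)
    return [
        i
        for i, (g, tgt) in enumerate(zip(trajectory_versions, target_weight_versions))
        if min_valid_version <= g <= current_weight_version and tgt == current_weight_version
    ]

# Example: at step 5 with max_age_steps=1, only groups generated at versions 4 or 5
# and targeted at step 5 are eligible.
print(select_indices([3, 4, 5], [5, 5, 5], current_weight_version=5, max_age_steps=1))  # [1, 2]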

self._generating_targets: set[int] = set()

def _calculate_target_weights(self, generation_weight_version: int) -> list[int]:
"""Calculate target weight versions for given generation weight version."""
Contributor:

Feel free to workshop this, but from our offline discussion:

Suggested change
"""Calculate target weight versions for given generation weight version."""
"""Calculate target weight versions for given generation weight version.

The list of versions returned enumerates the possible versions a generation
server can target. These versions are looped over to see which training
step they can target. If all target versions are exhausted, this generation
server will remain idle until the next weight update.

Example:
    generation_weight_version = 10
    max_trajectory_age_steps = 4

    Returns:
        [11, 12, 13, 14]  # Meaning this generation server can create trajectories for training steps 11, 12, 13, 14
"""

Contributor:

just an FYI, the sync grpo loop has multi-epoch support. keep that for a future PR?

Contributor Author:

Yes, will add in a separate PR.

# Import async utilities only when needed
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer

timer = Timer()
Contributor:

another feat missing in this loop is the timeout checker:

timeout = TimeoutChecker(

re: b74c5d0

Contributor Author:

For the next PR, don't want to add more core changes to algo before this version is checked in.

@terrykong (Contributor) left a comment:

@parthchadha were you planning on adding:

  1. a doc page
  2. unit tests
  3. functional tests (ex: https://github.com/NVIDIA-NeMo/RL/blob/main/tests/functional/L1_Functional_Tests_GPU.sh)

@coderabbitai bot left a comment:

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/grpo.py (1)

986-992: Fix config key: use grpo.val_period in validate()

master_config["dpo"]["val_period"] is wrong for GRPO and can assert incorrectly.

-    if val_dataloader is None:
-        assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
+    if val_dataloader is None:
+        assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
             "val_dataloader is None, so dpo.val_period must be 0"
         )
♻️ Duplicate comments (4)
nemo_rl/algorithms/grpo.py (4)

1195-1206: Pass per‑target quota to ReplayBuffer to avoid stalls and staleness

Add groups_per_target=num_prompts_per_step so the actor can enforce per-target capacity.

-    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
-        max_size=optimal_buffer_size
-    )
+    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+        max_size=optimal_buffer_size,
+        groups_per_target=num_prompts_per_step,
+    )

1308-1321: Clarify wait message and increment counter; throttle logs

  • “min=” can be misread as minutes.
  • wait_iterations never increments.
  • Prevent spam by logging every N iterations.
-    print(
-        f"⏳ Waiting for replay buffer to have sufficient trajectories (min={min_trajectories_needed})..."
-    )
+    print(
+        f"⏳ Waiting for replay buffer to have sufficient trajectories (min_trajectories_needed={min_trajectories_needed})..."
+    )
     wait_iterations = 0
     while True:
         buffer_size_current = ray.get(replay_buffer.size.remote())
-
-        print(
-            f"  Wait iteration {wait_iterations}: buffer_size={buffer_size_current}/{min_trajectories_needed}"
-        )
+        if wait_iterations % 10 == 0:
+            print(
+                f"  Wait iteration {wait_iterations}: buffer_filled_ratio={buffer_size_current}/{min_trajectories_needed}"
+            )
+        wait_iterations += 1

1138-1145: Add timeout-based checkpointing like sync loop

Prevents losing progress when a save deadline is near.

     timer = Timer()
+    timeout = TimeoutChecker(
+        timeout=master_config["checkpointing"]["checkpoint_must_save_by"],
+        fit_last_save_time=True,
+    )
+    timeout.start_iterations()
-                # Checkpointing (same as sync version)
-                consumed_samples += master_config["grpo"]["num_prompts_per_step"]
-                if master_config["checkpointing"]["enabled"] and (
-                    is_last_step
-                    or (step + 1) % master_config["checkpointing"]["save_period"] == 0
-                ):
+                # Checkpointing (+ timeout)
+                consumed_samples += master_config["grpo"]["num_prompts_per_step"]
+                timeout.mark_iteration()
+                should_save_by_timeout = timeout.check_save()
+                if master_config["checkpointing"]["enabled"] and (
+                    is_last_step
+                    or (step + 1) % master_config["checkpointing"]["save_period"] == 0
+                    or should_save_by_timeout
+                ):
-            timer.reset()
-            step += 1
+            timer.reset()
+            step += 1
+            if 'should_save_by_timeout' in locals() and should_save_by_timeout:
+                break

Also applies to: 1621-1626, 1727-1730


1148-1152: Unify save‑state schema: current_step vs step (prevents KeyError and incompatible checkpoints)

GRPOSaveState defines current_step/total_steps. Accessing/setting step will break resume.

-    step = grpo_save_state["step"]
+    step = grpo_save_state.get("current_step", 0)
-                    grpo_save_state["step"] = step + 1
+                    grpo_save_state["current_step"] = step + 1
+                    grpo_save_state["total_steps"] = grpo_save_state.get("total_steps", 0) + 1

Also applies to: 1628-1634

🧹 Nitpick comments (6)
nemo_rl/algorithms/grpo.py (6)

687-697: Avoid KeyError for optional overlong_filtering

Field is NotRequired; default to False.

-                    use_overlong_filtering = master_config["grpo"]["overlong_filtering"]
+                    use_overlong_filtering = master_config["grpo"].get("overlong_filtering", False)

1238-1240: Remove unused variable (Ruff F841) or make intent explicit

Assignment isn’t used; keep a throwaway to signal intentional fire‑and‑forget.

-    collection_task = trajectory_collector.start_collection.remote(dataloader)
+    _ = trajectory_collector.start_collection.remote(dataloader)

1499-1516: Rename unused loop var j (Ruff B007)

No functional change.

-                    for i, message_log in enumerate(repeated_batch["message_log"]):
-                        for j, message in enumerate(message_log):
+                    for i, message_log in enumerate(repeated_batch["message_log"]):
+                        for _j, message in enumerate(message_log):

1407-1444: Assertion message is long (TRY003); optional tidy

Minor: shorten or rely on the default message to satisfy linters.


1281-1305: Broad exception catches; prefer narrower or re‑raise with context

Catch specific Ray/Runtime errors; avoid swallowing unrelated failures.


1088-1091: GPU memory cleanup: optionally gate on CUDA availability

Harmless as-is; consider guarding to avoid surprises on CPU-only runs.

-    gc.collect()
-    torch.cuda.empty_cache()
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6868407 and c8ee01f.

📒 Files selected for processing (1)
  • nemo_rl/algorithms/grpo.py (4 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
nemo_rl/algorithms/grpo.py

1239-1239: Local variable collection_task is assigned to but never used

Remove assignment to unused variable collection_task

(F841)


1257-1257: Do not catch blind exception: Exception

(BLE001)


1268-1268: Do not catch blind exception: Exception

(BLE001)


1296-1296: Do not catch blind exception: Exception

(BLE001)


1439-1441: Avoid specifying long messages outside the exception class

(TRY003)


1500-1500: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


1640-1640: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


1735-1735: Do not catch blind exception: Exception

(BLE001)


1740-1740: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Lint check
🔇 Additional comments (1)
nemo_rl/algorithms/grpo.py (1)

1670-1672: Guard offload_after_refit for non‑colocated setups

This method’s semantics are tied to colocated refits.

-                    policy.offload_after_refit()
+                    if colocated_inference:
+                        policy.offload_after_refit()

Comment on lines 1128 to 1134
    assert _should_use_async_rollouts(master_config), (
        "Async GRPO requires vLLM backend with vllm_cfg.async_engine=True. "
        "Set policy.generation.vllm_cfg.async_engine to true in your config."
    )
    assert loss_fn["use_importance_sampling_correction"] is True, (
        "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
    )

⚠️ Potential issue

Use loss-fn config for IS correction check (current code will TypeError)

loss_fn is a callable/object, not a dict. Check master_config['loss_fn'] (or an attribute on loss_fn if exposed).

-    assert loss_fn["use_importance_sampling_correction"] is True, (
+    assert master_config["loss_fn"]["use_importance_sampling_correction"] is True, (
         "Importance sampling correction must be enabled for async GRPO for good convergence due to off-policy samples!"
     )
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1128 to 1134, the code asserts
loss_fn["use_importance_sampling_correction"] but loss_fn is a callable/object
which causes a TypeError; update the assertion to read the configuration from
master_config['loss_fn'] (e.g.
master_config['loss_fn']['use_importance_sampling_correction']) or, if the
loss_fn object exposes an attribute, use getattr(loss_fn,
'use_importance_sampling_correction', False) to check truthiness, and keep the
same descriptive error message if the flag is not enabled.
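
A minimal sketch of the defensive version of that check; whether the flag is read from the config dict or from an attribute on the loss object is an assumption about how it is exposed:

def importance_sampling_enabled(master_config: dict, loss_fn: object) -> bool:
    """Read the IS-correction flag from config if present, else from the loss object."""
    loss_cfg = master_config.get("loss_fn")
    if isinstance(loss_cfg, dict) and "use_importance_sampling_correction" in loss_cfg:
        return bool(loss_cfg["use_importance_sampling_correction"])
    return bool(getattr(loss_fn, "use_importance_sampling_correction", False))


# assert importance_sampling_enabled(master_config, loss_fn), (
#     "Importance sampling correction must be enabled for async GRPO"
# )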

Comment on lines +1526 to +1537
# Create training data
train_data = BatchedDataDict[ClippedPGLossDataDict](
{
"input_ids": flat_messages["token_ids"],
"input_lengths": input_lengths,
"advantages": flat_messages["advantages"],
"generation_logprobs": flat_messages["generation_logprobs"],
"token_mask": flat_messages["token_loss_mask"],
"sample_mask": repeated_batch["loss_multiplier"],
}
)
train_data.to("cpu")

⚠️ Potential issue

Parity with sync path: include multimodal fields in train_data

Async path omits get_multimodal_dict, which breaks multimodal inputs.

                     train_data = BatchedDataDict[ClippedPGLossDataDict](
                         {
                             "input_ids": flat_messages["token_ids"],
                             "input_lengths": input_lengths,
                             "advantages": flat_messages["advantages"],
                             "generation_logprobs": flat_messages["generation_logprobs"],
                             "token_mask": flat_messages["token_loss_mask"],
                             "sample_mask": repeated_batch["loss_multiplier"],
                         }
                     )
+                    # Keep multimodal parity with sync flow
+                    train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
                     train_data.to("cpu")
🤖 Prompt for AI Agents
In nemo_rl/algorithms/grpo.py around lines 1526 to 1537, the async training path
builds train_data without the multimodal fields, breaking multimodal inputs;
update the dict construction to include the result of
get_multimodal_dict(flat_messages) (merged into the mapping passed to
BatchedDataDict[ClippedPGLossDataDict]) so the same multimodal keys present on
the sync path are included, and ensure any multimodal tensors are also moved to
CPU via train_data.to("cpu") as done for the other fields.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/grpo.py (1)

986-992: Bug: references dpo in GRPO validate path

If val_dataloader is None, this raises a KeyError because GRPO configs have no "dpo" section. Reference the grpo config instead.

Apply:

     if val_dataloader is None:
-        assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, (
-            "val_dataloader is None, so dpo.val_period must be 0"
-        )
+        assert val_dataloader is not None or master_config["grpo"]["val_period"] == 0, (
+            "val_dataloader is None, so grpo.val_period must be 0"
+        )
♻️ Duplicate comments (5)
nemo_rl/algorithms/async_utils.py (3)

43-51: ReplayBuffer advances “last_target” too early — causes partial targets and stalls

Advancing last_target_weight_already_generated on the first push lets the collector skip remaining groups for that target. Enforce a per‑target quota and advance only after the quota is met.

Apply:

-class ReplayBuffer:
+class ReplayBuffer:
@@
-    def __init__(self, max_size: int):
+    def __init__(self, max_size: int, groups_per_target: int):
         self.max_size = max_size
+        self.groups_per_target = groups_per_target
         self.trajectories = []
@@
         self.last_target_weight_already_generated = -1
         self._lock = _threading.Lock()
+        self._target_counts: dict[int, int] = {}

53-81: Add per‑target quota and late‑advance semantics; return “complete” when quota met

Prevents over/under‑fill and aligns reservation release with full completion for a target.

Apply:

     def push_with_wait_signal(
@@
-        with self._lock:
+        with self._lock:
             if len(self.trajectories) >= self.max_size:
                 return "full"
 
             print("🔍 ReplayBuffer.push_with_wait_signal: Adding trajectory")
-            self.trajectories.append(trajectory)
-            self.trajectory_versions.append(weight_version)
-            self.target_weight_versions.append(target_weight_version)
-            self.last_target_weight_already_generated = max(
-                self.last_target_weight_already_generated, target_weight_version
-            )
+            current = self._target_counts.get(target_weight_version, 0)
+            if current >= self.groups_per_target:
+                return "redundant"
+            self.trajectories.append(trajectory)
+            self.trajectory_versions.append(weight_version)
+            self.target_weight_versions.append(target_weight_version)
+            self._target_counts[target_weight_version] = current + 1
+            # Advance only when quota is satisfied
+            if self._target_counts[target_weight_version] >= self.groups_per_target:
+                self.last_target_weight_already_generated = max(
+                    self.last_target_weight_already_generated, target_weight_version
+                )
             print(
                 f"ReplayBuffer state: {len(self.trajectories)} groups, versions={self.trajectory_versions}, targets={self.target_weight_versions}, last_target_weight_already_generated={self.last_target_weight_already_generated}"
             )
-            return "success"
+            return (
+                "complete"
+                if self._target_counts[target_weight_version] >= self.groups_per_target
+                else "success"
+            )
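
Taken on its own, the proposed state machine reduces to roughly this toy sketch (not the committed code; the class name and payload shape are illustrative):

import threading


class QuotaGatedBuffer:
    """Toy model of the per-target quota: a target only counts as generated
    once groups_per_target groups have been pushed for it."""

    def __init__(self, max_size: int, groups_per_target: int) -> None:
        self.max_size = max_size
        self.groups_per_target = groups_per_target
        self._items: list[tuple[int, object]] = []  # (target_version, trajectory)
        self._counts: dict[int, int] = {}
        self.last_fully_generated_target = -1
        self._lock = threading.Lock()

    def push(self, trajectory: object, target_version: int) -> str:
        with self._lock:
            if len(self._items) >= self.max_size:
                return "full"
            count = self._counts.get(target_version, 0)
            if count >= self.groups_per_target:
                return "redundant"  # quota already met; refuse to over-fill this target
            self._items.append((target_version, trajectory))
            self._counts[target_version] = count + 1
            if count + 1 >= self.groups_per_target:
                # Advance only now, so the collector cannot skip a half-filled target.
                self.last_fully_generated_target = max(
                    self.last_fully_generated_target, target_version
                )
                return "complete"
            return "success"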

609-643: Release reservations only when a target is fully buffered

Current code releases on the first prompt, risking permanent underfill if others fail.

Apply:

-                    if status == "success":
+                    if status in ("success", "complete"):
                         print(
                             f"📦 Buffered per-prompt group (prompt_idx {prompt_idx}, target_weight {target_weight_version})"
                         )
-
-                        # Release reservation when FIRST prompt group for this target is successfully buffered
-                        if prompt_idx == 0:
-                            with self._generation_check_lock:
-                                if target_weight_version in self._generating_targets:
-                                    self._generating_targets.discard(
-                                        target_weight_version
-                                    )
-                                    print(
-                                        f"🧹 Released reservation for target weight {target_weight_version} (first prompt buffered)"
-                                    )
+                        if status == "complete":
+                            with self._generation_check_lock:
+                                if target_weight_version in self._generating_targets:
+                                    self._generating_targets.discard(target_weight_version)
+                                    print(f"🧹 Released reservation for target weight {target_weight_version} (all prompts buffered)")
                         break
nemo_rl/algorithms/grpo.py (2)

1204-1207: Pass per‑target quota into ReplayBuffer

Required for the quota/late‑advance semantics above.

Apply:

-    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
-        max_size=optimal_buffer_size
-    )
+    replay_buffer = ReplayBuffer.options(runtime_env=_replay_runtime_env).remote(
+        max_size=optimal_buffer_size,
+        groups_per_target=num_prompts_per_step,
+    )

1507-1519: Async path misses multimodal fields present in sync path

Parity fix to keep multimodal inputs working.

Apply:

                     train_data = BatchedDataDict[ClippedPGLossDataDict](
                         {
                             "input_ids": flat_messages["token_ids"],
                             "input_lengths": input_lengths,
                             "advantages": flat_messages["advantages"],
                             "generation_logprobs": flat_messages["generation_logprobs"],
                             "token_mask": flat_messages["token_loss_mask"],
                             "sample_mask": repeated_batch["loss_multiplier"],
                         }
                     )
+                    # Keep multimodal parity with sync flow
+                    train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
                     train_data.to("cpu")
🧹 Nitpick comments (8)
examples/configs/grpo_math_8B.yaml (1)

7-9: Clarify async prerequisites and tune default age

Async requires both vLLM async_engine=true and IS correction enabled; consider documenting this inline and defaulting age to >=2 if your convergence plots support it.

Apply:

   async_grpo:
-    enabled: false
-    max_trajectory_age_steps: 1
+    enabled: false
+    # Requires: policy.generation.vllm_cfg.async_engine=true and loss_fn.use_importance_sampling_correction=true
+    max_trajectory_age_steps: 1  # Increase to 2–4 if targeting higher throughput and your runs remain stable
examples/configs/grpo_math_1B.yaml (1)

28-30: Auto‑wire IS correction to async toggle

Avoids drift between config and the runtime assert.

Apply:

-  # Async GRPO requires importance sampling correction enabled
-  # Set to true when async_grpo.enabled is true
-  use_importance_sampling_correction: false
+  # Async GRPO requires importance sampling correction enabled
+  # Auto-enable when async_grpo is enabled
+  use_importance_sampling_correction: ${grpo.async_grpo.enabled}
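
The ${...} reference assumes the YAML is loaded through OmegaConf-style interpolation; a quick standalone check of how the value would resolve (keys mirrored from the snippet above):

from omegaconf import OmegaConf

yaml_src = """
grpo:
  async_grpo:
    enabled: true
loss_fn:
  use_importance_sampling_correction: ${grpo.async_grpo.enabled}
"""
cfg = OmegaConf.create(yaml_src)
# The interpolation resolves on access and tracks the async toggle.
assert cfg.loss_fn.use_importance_sampling_correction is True
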
nemo_rl/algorithms/async_utils.py (3)

317-336: Avoid duplicate targeting; consult existing targets set

If the process restarts or quotas change, use get_existing_target_weights to skip duplicates, not just last_target.

Apply:

-        last_target_weight_already_generated = ray.get(
-            self.replay_buffer.get_last_target_weight_already_generated.remote()
-        )
+        last_target_weight_already_generated, existing_targets = ray.get(
+            [
+                self.replay_buffer.get_last_target_weight_already_generated.remote(),
+                self.replay_buffer.get_existing_target_weights.remote(),
+            ]
+        )
@@
-                if (
+                if (
                     target_weight > last_target_weight_already_generated
-                    and target_weight not in self._generating_targets
+                    and target_weight not in self._generating_targets
+                    and target_weight not in existing_targets
                 ):

404-415: Minor: unify target list logging with calculator to avoid off‑by‑one confusion

Use _calculate_target_weights for the log instead of rebuilding with range(max_age).

Apply:

-                        target_weights = [
-                            self.current_weight_version + i
-                            for i in range(max_trajectory_age)
-                        ]
+                        target_weights = self._calculate_target_weights(self.current_weight_version)

370-371: Narrow broad excepts and keep messages on the exception

Catching bare Exception obscures actionable errors and trips linters. Catch specific errors, or at least log the exception with repr(e) and re-raise where appropriate.

Also applies to: 436-436, 501-501, 643-643, 648-648, 669-669

nemo_rl/algorithms/grpo.py (3)

1481-1497: Rename unused loop variable j

Minor lint fix.

Apply:

                    for i, message_log in enumerate(repeated_batch["message_log"]):
-                        for j, message in enumerate(message_log):
+                        for _j, message in enumerate(message_log):

And similarly in the sync loop.

Also applies to: 699-715


1621-1624: warnings.warn: set stacklevel for actionable source lines

Improves navigation from logs.

Apply:

                             warnings.warn(
                                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                                "Saving most recent k checkpoints instead."
-                            )
+                                "Saving most recent k checkpoints instead.",
+                                stacklevel=2,
+                            )
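
For context, stacklevel=2 attributes the warning to the caller of the function that emits it, which is the line you actually want to jump to from the logs; a minimal standalone illustration (names made up):

import warnings


def save_checkpoint(save_state: dict, metric_name: str) -> None:
    if metric_name not in save_state:
        # stacklevel=2 reports the warning at the call site of save_checkpoint,
        # not at this line inside the helper.
        warnings.warn(
            f"{metric_name} not found in save state; saving most recent k checkpoints.",
            stacklevel=2,
        )


save_checkpoint({}, "val_reward")  # the warning is attributed to this line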

1652-1653: Only offload_after_refit for colocated inference

Avoid potential no‑op or unexpected behavior when non‑colocated.

Apply:

-                    policy.offload_after_refit()
+                    if colocated_inference:
+                        policy.offload_after_refit()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c8ee01f and 62eec29.

📒 Files selected for processing (5)
  • examples/configs/grpo_math_1B.yaml (2 hunks)
  • examples/configs/grpo_math_8B.yaml (1 hunks)
  • examples/run_grpo_math.py (1 hunks)
  • nemo_rl/algorithms/async_utils.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/run_grpo_math.py
🧰 Additional context used
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (6)
nemo_rl/utils/timer.py (1)
  • time (110-123)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
  • get_actor_python_env (47-62)
nemo_rl/utils/venvs.py (1)
  • create_local_venv_on_each_node (152-189)
nemo_rl/algorithms/async_utils.py (12)
  • AsyncTrajectoryCollector (237-672)
  • ReplayBuffer (36-233)
  • start_collection (373-384)
  • set_weight_version (338-347)
  • pause (510-513)
  • resume (515-518)
  • size (223-226)
  • sample (100-221)
  • get_debug_info (82-89)
  • prepare_for_refit (520-538)
  • resume_after_refit (540-543)
  • get_dataloader_state (567-571)
nemo_rl/distributed/batched_data_dict.py (4)
  • size (793-802)
  • BatchedDataDict (75-839)
  • from_batches (102-151)
  • to (804-811)
nemo_rl/data/llm_message_utils.py (1)
  • batched_message_log_to_flat_message (233-390)
nemo_rl/algorithms/async_utils.py (6)
nemo_rl/algorithms/grpo.py (1)
  • MasterConfig (121-129)
nemo_rl/data/interfaces.py (1)
  • DatumSpec (32-40)
nemo_rl/distributed/batched_data_dict.py (2)
  • BatchedDataDict (75-839)
  • repeat_interleave (703-724)
nemo_rl/environments/interfaces.py (1)
  • EnvironmentInterface (52-88)
nemo_rl/experience/rollouts.py (1)
  • run_async_multi_turn_rollout (751-895)
nemo_rl/models/generation/interfaces.py (1)
  • GenerationInterface (208-242)
🪛 Ruff (0.12.2)
nemo_rl/algorithms/grpo.py

1239-1239: Local variable collection_task is assigned to but never used

Remove assignment to unused variable collection_task

(F841)


1257-1257: Do not catch blind exception: Exception

(BLE001)


1268-1268: Do not catch blind exception: Exception

(BLE001)


1296-1296: Do not catch blind exception: Exception

(BLE001)


1420-1422: Avoid specifying long messages outside the exception class

(TRY003)


1481-1481: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


1621-1621: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


1716-1716: Do not catch blind exception: Exception

(BLE001)


1721-1721: Do not catch blind exception: Exception

(BLE001)

nemo_rl/algorithms/async_utils.py

141-143: Avoid specifying long messages outside the exception class

(TRY003)


369-369: Consider moving this statement to an else block

(TRY300)


370-370: Do not catch blind exception: Exception

(BLE001)


436-436: Do not catch blind exception: Exception

(BLE001)


501-501: Do not catch blind exception: Exception

(BLE001)


643-643: Do not catch blind exception: Exception

(BLE001)


648-648: Do not catch blind exception: Exception

(BLE001)


669-669: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Lint check
🔇 Additional comments (3)
examples/configs/grpo_math_1B.yaml (1)

16-19: Async block: add explicit cross‑refs to required flags

A brief note here reduces misconfigurations when toggling async.
Apply:

   async_grpo:
-    enabled: false # Set to true to enable async training mode
+    enabled: false # Set to true to enable async training mode (requires vLLM async_engine and IS correction)
     # Max age (in training steps) for trajectories used in training
     max_trajectory_age_steps: 1
nemo_rl/algorithms/grpo.py (2)

1238-1240: Remove unused variable assignment

Avoids F841 and clarifies that start_collection is fire‑and‑forget.
Apply:

-    collection_task = trajectory_collector.start_collection.remote(dataloader)
+    trajectory_collector.start_collection.remote(dataloader)

1176-1194: Verify Ray runtime_env usage of py_executable

Ensure your Ray version honors runtime_env['py_executable']; otherwise actors may launch with the driver Python. Affected: nemo_rl/algorithms/grpo.py:1176–1194 (ReplayBuffer). Also review nemo_rl/distributed/worker_groups.py and nemo_rl/distributed/ray_actor_environment_registry.py.

Sandbox could not import ray (ModuleNotFoundError); run locally to confirm Ray version and behavior:

  • python -c "import ray; print(ray.__version__)"
  • Minimal test: create a tiny actor with .options(runtime_env={'py_executable': '/path/to/python'}) that returns sys.executable and verify it matches the provided py_executable (a sketch follows below).
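
A sketch of that minimal test, assuming a Ray release where runtime_env accepts py_executable (older versions may ignore or reject the field); the interpreter path is a placeholder:

import sys

import ray

ray.init()


@ray.remote
class PyExecProbe:
    def which_python(self) -> str:
        return sys.executable


expected = "/path/to/venv/bin/python"  # placeholder for the worker interpreter you expect
probe = PyExecProbe.options(runtime_env={"py_executable": expected}).remote()
actual = ray.get(probe.which_python.remote())
print(f"actor python: {actual} (expected: {expected})")
assert actual == expected, "runtime_env['py_executable'] was not honored"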

Comment on lines +100 to +222
print(f" {self.trajectory_versions=}")

# For debugging: check for unexpected old trajectories
from collections import Counter

version_counts = Counter(self.trajectory_versions)
print(f" {version_counts=}")

# Compute minimum valid version based on age window
# max_age_steps=1 means trajectories from the last 1 step are valid
min_valid_version = max(0, current_weight_version - max_age_steps)
print(f" {min_valid_version=}")

# Check for unexpected old trajectories
old_trajectories = [
v for v in self.trajectory_versions if v < min_valid_version
]
if old_trajectories:
raise ValueError(
f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}"
)

# Filter for valid trajectories without modifying the buffer
valid_indices = [
i
for i, v in enumerate(self.trajectory_versions)
if min_valid_version <= v <= current_weight_version
]
print(
f" valid_indices: {len(valid_indices)}/{total_trajectories} trajectories within age window"
)
if not valid_indices:
print("No trajectories available for sampling.")
return None

# Enforce exact number of groups if available; otherwise, signal to wait
if len(valid_indices) < num_prompt_groups:
print(
f"Insufficient valid groups: have {len(valid_indices)}, need {num_prompt_groups}. Waiting for buffer to fill."
)
return None

# Only select trajectories intended for the current training step
# This ensures no trajectory loses its "last chance" to be used for its intended step
intended_indices = [
i
for i in valid_indices
if self.target_weight_versions[i] == current_weight_version
]

print(
f" 🎯 Found {len(intended_indices)} trajectories intended for current step {current_weight_version}"
)

# Stall training if we don't have enough trajectories intended for this step
if len(intended_indices) < num_prompt_groups:
print(
f" ⏸️ STALLING: Need {num_prompt_groups} trajectories for step {current_weight_version}, but only {len(intended_indices)} are ready"
)
print(
f" ⏸️ Training will wait for remaining {num_prompt_groups - len(intended_indices)} trajectories to be generated"
)
return None

# Select exactly the trajectories intended for this step (FIFO within same target)
selected: list[int] = intended_indices[:num_prompt_groups]
print(
f" ✅ Selected {len(selected)} trajectories all intended for step {current_weight_version}"
)

from collections import Counter

sampled_weights = [self.trajectory_versions[i] for i in selected]
avg_trajectory_age = current_weight_version - sum(sampled_weights) / len(
sampled_weights
)
print(
f"✅ Selected counts by generation weight-version: {Counter(sampled_weights)}"
)
print(f"📊 Average trajectory age: {avg_trajectory_age:.2f} steps")
print(
f"🎯 All selected trajectories target step {current_weight_version} (100% target match)"
)

sampled_items = [self.trajectories[i] for i in selected]

# Remove selected items in reverse order to maintain correct indices
for idx in sorted(selected, reverse=True):
self.trajectory_versions.pop(idx)
self.target_weight_versions.pop(idx)
self.trajectories.pop(idx)
print(
f"🗑️ Consumed and removed {len(selected)} groups from buffer, old buffer size: {total_trajectories}, new buffer size: {len(self.trajectories)}, new target weight versions {self.target_weight_versions}"
)

return {
"trajectories": sampled_items,
"avg_trajectory_age": avg_trajectory_age,
}


⚠️ Potential issue

Sampling: drop stale data instead of raising; rely on intended_indices only

Raising on stale items crashes training; also the precheck on valid_indices is redundant given intended_indices gating.

Apply:

-            old_trajectories = [
-                v for v in self.trajectory_versions if v < min_valid_version
-            ]
-            if old_trajectories:
-                raise ValueError(
-                    f"Found {len(old_trajectories)} trajectories older than min_valid_version {min_valid_version}"
-                )
+            # Drop stale entries quietly to avoid stalls/crashes
+            stale_indices = [i for i, v in enumerate(self.trajectory_versions) if v < min_valid_version]
+            for idx in sorted(stale_indices, reverse=True):
+                self.trajectories.pop(idx)
+                self.trajectory_versions.pop(idx)
+                tw = self.target_weight_versions.pop(idx)
+                if hasattr(self, "_target_counts"):
+                    self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+                    if self._target_counts[tw] <= 0:
+                        del self._target_counts[tw]
@@
-            if len(valid_indices) < num_prompt_groups:
-                print(
-                    f"Insufficient valid groups: have {len(valid_indices)}, need {num_prompt_groups}. Waiting for buffer to fill."
-                )
-                return None
+            # Do not early-return here; we stall based on intended_indices below
@@
             intended_indices = [
@@
             if len(intended_indices) < num_prompt_groups:
@@
             for idx in sorted(selected, reverse=True):
-                self.trajectory_versions.pop(idx)
-                self.target_weight_versions.pop(idx)
-                self.trajectories.pop(idx)
+                tw = self.target_weight_versions.pop(idx)
+                self.trajectory_versions.pop(idx)
+                self.trajectories.pop(idx)
+                if hasattr(self, "_target_counts"):
+                    self._target_counts[tw] = self._target_counts.get(tw, 1) - 1
+                    if self._target_counts[tw] <= 0:
+                        del self._target_counts[tw]
🧰 Tools
🪛 Ruff (0.12.2)

141-143: Avoid specifying long messages outside the exception class

(TRY003)
