
feat: skip logprob and reference logprob computation under certain conditions#1891

Open
guyueh1 wants to merge 9 commits into NVIDIA-NeMo:main from guyueh1:fuse_logprob_train

Conversation

@guyueh1
Contributor

@guyueh1 guyueh1 commented Feb 6, 2026

What does this PR do?

Skip logprob and reference logprob computation under certain conditions:

  • When loss_fn.skip_reference_policy_logprobs_calculation=true, skip the reference logprob computation. This requires loss_fn.reference_kl_penalty == 0, which is checked whenever skip_reference_policy_logprobs_calculation is true.
  • When loss_fn.force_on_policy_ratio=true, skip the logprob computation. This requires the rollout batch size to equal the train global batch size, which is checked whenever force_on_policy_ratio is true. (See the sketch below for both checks.)
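
For illustration, a minimal sketch of the two checks described above; the helper name and arguments are hypothetical, not the PR's actual code:

def validate_skip_flags(
    skip_reference_policy_logprobs_calculation: bool,
    reference_kl_penalty: float,
    force_on_policy_ratio: bool,
    rollout_batch_size: int,
    train_global_batch_size: int,
) -> None:
    # Hypothetical sketch of the config validation described in this PR.
    if skip_reference_policy_logprobs_calculation:
        # Reference logprobs are only safe to skip when the KL penalty is zero.
        assert reference_kl_penalty == 0, (
            "skip_reference_policy_logprobs_calculation=true requires "
            "loss_fn.reference_kl_penalty == 0"
        )
    if force_on_policy_ratio:
        # Skipping prev logprobs is only valid for strictly on-policy updates,
        # i.e. every rollout sample is trained on exactly once per step.
        assert rollout_batch_size == train_global_batch_size, (
            "force_on_policy_ratio=true requires rollout batch size == "
            "train global batch size"
        )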

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
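
For example (illustrative only; the exact config sections these keys live in should be checked against the diff), the flags might be enabled together with a zero KL penalty like this:

# Illustrative override dict; key placement mirrors the summaries in this
# thread (skip flag under grpo, force_on_policy_ratio under loss_fn).
overrides = {
    "grpo": {
        "skip_reference_policy_logprobs_calculation": True,
    },
    "loss_fn": {
        "force_on_policy_ratio": True,
        "reference_kl_penalty": 0.0,
    },
}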

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features
    • Added configuration options to optimize GRPO training: the ability to skip reference-policy logprob calculations and to enforce an on-policy ratio in loss-function computations.
    • The new flags reduce computational overhead during training while maintaining training stability.

Signed-off-by: Guyue Huang <guyueh@nvidia.com>
@guyueh1 guyueh1 self-assigned this Feb 6, 2026
@guyueh1 guyueh1 added the deepseek Related to deepseek 671b label Feb 6, 2026
@guyueh1 guyueh1 marked this pull request as ready for review February 11, 2026 22:17
@guyueh1 guyueh1 requested review from a team as code owners February 11, 2026 22:17
@guyueh1 guyueh1 added the CI:L1 Run doctests, unit tests, and functional tests label Feb 11, 2026
@guyueh1 guyueh1 requested a review from HeyyyyyyG February 11, 2026 22:18
@guyueh1 guyueh1 changed the title from "feat: Fuse logprob and train when rollout and train have same batch size" to "feat: skip logprob and reference logprob computation under certain conditions" Feb 11, 2026
@coderabbitai
Contributor

coderabbitai bot commented Feb 11, 2026

📝 Walkthrough

This PR introduces configurable flags to optimize GRPO training by enabling optional skipping of logprob computations. Configuration files are updated with force_on_policy_ratio and skip_reference_policy_logprobs_calculation flags, while the algorithm implementation adds conditional gating for logprob computation in GRPO training and loss function evaluation.

Changes

  • Configuration Updates (examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml, examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml): Added the skip_reference_policy_logprobs_calculation: true flag under the grpo section to enable skipping reference policy logprob calculations.
  • GRPO Recipe Configurations (examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml, examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml, examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml): Added a loss_fn.force_on_policy_ratio: true configuration block under the grpo section to enforce the on-policy ratio in loss computation.
  • GRPO Algorithm Implementation (nemo_rl/algorithms/grpo.py): Added conditional gating for logprob computations based on the force_on_policy_ratio and skip_reference_policy_logprobs_calculation flags; sets train_data.prev_logprobs to zeros when skipping the prev_logprobs computation.
  • Loss Function Logic (nemo_rl/algorithms/loss_functions.py): Modified ClippedPGLossFn.__call__ to compute on-policy curr_logprobs internally when force_on_policy_ratio is enabled; handles distributed scenarios (vocab_parallel_group, DTensor) with appropriate vocab range and padding adjustments; and sets prev_logprobs to the computed curr_logprobs for on-policy behavior (see the sketch after this list).
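
To make the on-policy shortcut concrete, here is a minimal, self-contained sketch (not the actual ClippedPGLossFn code) of why setting prev_logprobs = curr_logprobs.detach() makes the probability ratio exactly 1 while gradients still flow through curr_logprobs:

import torch

# Toy token logprobs under the current policy, shape [batch, seq_len].
curr_logprobs = torch.randn(2, 4, requires_grad=True)
advantages = torch.randn(2, 4)

# force_on_policy_ratio: treat the current policy as the behavior policy.
prev_logprobs = curr_logprobs.detach()

# PPO-style ratio; with prev == curr.detach() it is exactly 1 everywhere,
# so clipping never activates, yet the gradient reduces to the vanilla
# policy-gradient estimator (exp(curr - const) * advantage).
ratio = torch.exp(curr_logprobs - prev_logprobs)
loss = -(ratio * advantages).mean()
loss.backward()

print(ratio)               # all ones
print(curr_logprobs.grad)  # nonzero: gradients still flow to the policy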

Sequence Diagram

sequenceDiagram
    participant TrainingLoop as GRPO Training Loop
    participant LogprobCalc as Logprob Calculation
    participant LossFunc as Loss Function
    participant DataPrep as Data Preparation

    TrainingLoop->>DataPrep: Prepare training data
    
    alt skip_reference_policy_logprobs_calculation == false
        DataPrep->>LogprobCalc: Compute reference_policy_logprobs
        LogprobCalc-->>DataPrep: reference_policy_logprobs
    else
        DataPrep-->>DataPrep: Skip reference logprob computation
    end
    
    alt skip_prev_logprobs == false
        DataPrep->>LogprobCalc: Compute prev_logprobs
        LogprobCalc-->>DataPrep: prev_logprobs
    else
        DataPrep-->>DataPrep: Set prev_logprobs to zeros
    end
    
    DataPrep->>LossFunc: Pass data with logprobs
    
    alt force_on_policy_ratio == true
        LossFunc->>LossFunc: Compute curr_logprobs on-policy
        LossFunc->>LossFunc: Override prev_logprobs with curr_logprobs
    else
        LossFunc->>LossFunc: Use provided prev_logprobs
    end
    
    LossFunc-->>TrainingLoop: Compute loss
    TrainingLoop->>TrainingLoop: Backpropagate

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 inconclusive

❌ Failed checks (1 inconclusive)
  • Test Results For Major Changes: ❓ Inconclusive. The PR contains significant changes to loss function computation and logprob handling, but no test results or testing documentation are visible in the provided context. Resolution: search for test results documentation in PR comments, CI/CD logs, or attached test reports to confirm whether testing was performed.

✅ Passed checks (3 passed)
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'skip logprob and reference logprob computation under certain conditions' accurately describes the main changes in the PR, which add skip flags and conditional gating for logprob computations across config files and algorithm implementations.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1579-1597: The code currently zero-fills
train_data["prev_logprobs"] when force_on_policy_ratio is True which leads to
misleading logs and plots (token_mult_prob_error); change the handling so that
when master_config["loss_fn"].get("force_on_policy_ratio", False) is True you
either (A) avoid emitting prev_logprobs into log_data and skip plotting
token_mult_prob_error, or (B) back-fill train_data["prev_logprobs"] with the
actual on-policy probabilities returned by the training step (e.g., use
train_results["curr_logprobs"] / .detach() if present) before any
logging/visualization; update the code paths around prev_logprobs,
train_results, log_data["prev_logprobs"], and token_mult_prob_error to implement
one of these behaviors.
🧹 Nitpick comments (3)
nemo_rl/algorithms/grpo.py (3)

337-337: Explain why NRL_IGNORE_TP_ACCURACY_CHECK is needed when force_on_policy_ratio is enabled.

Setting a global environment variable as a side effect of a config flag is opaque. Consider adding a comment explaining why the TP accuracy check must be disabled here, so future maintainers understand the coupling.


1602-1621: Minor: logprob_data is allocated even when both logprob computations are skipped.

When both skip_prev_logprobs and skip_reference_policy_logprobs are True, the logprob_data dict on lines 1602–1608 is created but never read. This is lightweight (just references, no tensor copies), so it's not a real concern — just a nit for clarity.


2601-2633: Duplicated skip-logic between grpo_train and async_grpo_train.

Lines 2601–2633 are nearly identical to lines 1579–1621 in grpo_train. Consider extracting the flag resolution and conditional prepare_for_lp_inference / logprob gating into a shared helper to keep both paths in sync and reduce maintenance burden.
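
For illustration only, one possible shape for such a shared helper (the name resolve_logprob_skips is hypothetical, not code from this PR):

def resolve_logprob_skips(master_config: dict) -> tuple[bool, bool]:
    # Hypothetical helper consolidating the flag resolution duplicated in
    # grpo_train and async_grpo_train in this diff.
    force_on_policy_ratio = master_config["loss_fn"].get(
        "force_on_policy_ratio", False
    )
    skip_prev_logprobs = force_on_policy_ratio
    skip_reference_policy_logprobs = master_config["grpo"].get(
        "skip_reference_policy_logprobs_calculation", False
    )
    return skip_prev_logprobs, skip_reference_policy_logprobs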

Comment on lines +1579 to 1597
force_on_policy_ratio = master_config["loss_fn"].get(
    "force_on_policy_ratio", False
)
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
    "skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
    print(
        "Skipping prev_logprobs computation due to force_on_policy_ratio=True"
    )
    train_data["prev_logprobs"] = torch.zeros_like(
        train_data["generation_logprobs"]
    )
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
    print("▶ Preparing for logprob inference...", flush=True)
    with timer.time("logprob_inference_prep"):
        policy.prepare_for_lp_inference()

⚠️ Potential issue | 🟡 Minor

Zero-filled prev_logprobs will produce misleading logs and diagnostics.

When force_on_policy_ratio=True, train_data["prev_logprobs"] is filled with zeros. The loss function correctly overrides this internally with curr_logprobs.detach(), but downstream code still reads the raw zeros:

  • Line 1886: log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() — logs zeros to JSONL.
  • Lines 1897–1909: The token_mult_prob_error visualization plots train_data["prev_logprobs"] (zeros) against generation_logprobs, producing a nonsensical chart.

Consider either (a) skipping these log entries when force_on_policy_ratio is on, or (b) back-filling train_data["prev_logprobs"] with the actual curr_logprobs returned from train_results (if available).
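
For reference, a rough sketch of option (b); the train_results key name curr_logprobs is an assumption and may not match the actual training output:

def backfill_prev_logprobs(
    train_data: dict, train_results: dict, force_on_policy_ratio: bool
) -> None:
    # Hypothetical option (b): replace the zero-filled prev_logprobs with the
    # on-policy logprobs actually used by the loss, so the JSONL logs and the
    # token_mult_prob_error plot stay meaningful.
    if force_on_policy_ratio and "curr_logprobs" in train_results:
        train_data["prev_logprobs"] = train_results["curr_logprobs"].detach()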


@youngeunkwon0405
Contributor

I have two questions related to logprob skipping.

  1. For the reference logprob, shouldn't we always skip the reference logprob when the loss_fn.reference_kl_penalty == 0? Why would we need an additional argument loss_fn.skip_reference_policy_logprobs_calculation?
  2. For the prev_logprob, I think in theory (please correct me if I am wrong), if the generation model and policy model are numerically identical, then generation_logprob == prev_logprob and we could always skip the prev_logprob calculation unless we need to report the mult_prob_error, gen_kl_error like metrics. I think having an argument like loss_fn.use_generation_logprob could be more intuitive (if it is true and seq_logprob_error_threshold==false, skip the prev_logprob calculation).
  3. For the force_on_policy_ratio, I think it was just asserting that the training batch size is equal to the generation batch size, so it can use the current_logprob to skip the prev_logprob calculation. I think we can just skip the prev_logprob calculation whenever train_bs == gen_bs, seq_logprob_error_threshold is disabled, and no other feature requires the logprob-error metrics.

@guyueh1 guyueh1 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 12, 2026
@guyueh1 guyueh1 requested a review from terrykong February 13, 2026 21:28
@guyueh1
Contributor Author

guyueh1 commented Feb 13, 2026

I have two questions related to logprob skipping.

  1. For the reference logprob, shouldn't we always skip the reference logprob when the loss_fn.reference_kl_penalty == 0? Why would we need an additional argument loss_fn.skip_reference_policy_logprobs_calculation?
  2. For the prev_logprob, I think in theory (please correct me if I am wrong), if the generation model and policy model are numerically identical, then generation_logprob == prev_logprob and we could always skip the prev_logprob calculation unless we need to report the mult_prob_error, gen_kl_error like metrics. I think having an argument like loss_fn.use_generation_logprob could be more intuitive (if it is true and seq_logprob_error_threshold==false, skip the prev_logprob calculation).
  3. For the force_on_policy_ratio, I think it was just asserting that the training batch size is equal to the generation batch size, so it can use the current_logprob to skip the prev_logprob calculation. I think we can just skip the prev_logprob calculation whenever train_bs == gen_bs, seq_logprob_error_threshold is disabled, and no other feature requires the logprob-error metrics.
  1. I think the logic for skip_reference_policy_logprobs_calculation already exists in the codebase (skip_reference_policy_logprobs_calculation: NotRequired[bool]); in this PR I just refactored it a bit and added the logic to async_grpo_train. But I think your point is valid. The current logic is: the user needs to explicitly set skip_reference_policy_logprobs_calculation to skip the computation, and it only works when the KL penalty == 0. The correct logic would be: skip it whenever the KL penalty is 0. @terrykong which one is better?

I am still thinking about 2, but I do agree with 3: when certain conditions are met, we should skip the logprob computation even if the user doesn't explicitly specify force_on_policy_ratio. Currently we are coupling a perf optimization with an algorithmic feature that the user wants to enforce, and that's bad. I will revise based on 3.
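
For example, the automatic gating discussed above might look roughly like this (purely illustrative; the parameter names follow the comments in this thread and the final implementation may differ):

def should_skip_prev_logprobs(
    rollout_batch_size: int,
    train_global_batch_size: int,
    seq_logprob_error_threshold_enabled: bool,
    needs_logprob_error_metrics: bool,
) -> bool:
    # Hypothetical condition: skip prev_logprobs only when training is
    # strictly on-policy and nothing downstream needs the logprob-error
    # diagnostics (mult_prob_error, gen_kl_error, etc.).
    return (
        rollout_batch_size == train_global_batch_size
        and not seq_logprob_error_threshold_enabled
        and not needs_logprob_error_metrics
    )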

Contributor

@terrykong terrykong left a comment


@guyueh1 I believe this covers 012af2a51d22def1ff170ecb3594ffe538898ff9, but just double checking

Contributor

@terrykong terrykong left a comment


@guyueh1 and c759a4eba5d739650e6276ad0d4e3aa03985bb34

@guyueh1
Contributor Author

guyueh1 commented Feb 14, 2026

012af2a51d22def1ff170ecb3594ffe538898ff9

No, but I think I can add this to my PR

c759a4eba5d739650e6276ad0d4e3aa03985bb34

Yes


Labels

CI:L1 (Run doctests, unit tests, and functional tests), deepseek (Related to deepseek 671b), super-v3
