feat: skip logprob and reference logprob computation under certain conditions #1891

guyueh1 wants to merge 9 commits into NVIDIA-NeMo:main from
Conversation
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
📝 Walkthrough

This PR introduces configurable flags to optimize GRPO training by enabling optional skipping of logprob computations. Configuration files are updated with the new flags.
Sequence Diagram

```mermaid
sequenceDiagram
    participant TrainingLoop as GRPO Training Loop
    participant LogprobCalc as Logprob Calculation
    participant LossFunc as Loss Function
    participant DataPrep as Data Preparation

    TrainingLoop->>DataPrep: Prepare training data

    alt skip_reference_policy_logprobs_calculation == false
        DataPrep->>LogprobCalc: Compute reference_policy_logprobs
        LogprobCalc-->>DataPrep: reference_policy_logprobs
    else
        DataPrep-->>DataPrep: Skip reference logprob computation
    end

    alt skip_prev_logprobs == false
        DataPrep->>LogprobCalc: Compute prev_logprobs
        LogprobCalc-->>DataPrep: prev_logprobs
    else
        DataPrep-->>DataPrep: Set prev_logprobs to zeros
    end

    DataPrep->>LossFunc: Pass data with logprobs

    alt force_on_policy_ratio == true
        LossFunc->>LossFunc: Compute curr_logprobs on-policy
        LossFunc->>LossFunc: Override prev_logprobs with curr_logprobs
    else
        LossFunc->>LossFunc: Use provided prev_logprobs
    end

    LossFunc-->>TrainingLoop: Compute loss
    TrainingLoop->>TrainingLoop: Backpropagate
```
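As a rough illustration of the loss-side branch in the diagram, the sketch below shows how forcing an on-policy ratio can be implemented by detaching the current logprobs; the function name and signature are hypothetical and not taken from this PR:

```python
import torch

def policy_ratio(
    curr_logprobs: torch.Tensor,
    prev_logprobs: torch.Tensor,
    force_on_policy_ratio: bool = False,
) -> torch.Tensor:
    # Hypothetical helper mirroring the "override prev_logprobs with curr_logprobs" step.
    if force_on_policy_ratio:
        # With prev == curr (detached), the importance ratio is exactly 1 on-policy,
        # so no separate prev_logprobs inference pass is needed.
        prev_logprobs = curr_logprobs.detach()
    return torch.exp(curr_logprobs - prev_logprobs)
```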
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 inconclusive
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1579-1597: The code currently zero-fills
train_data["prev_logprobs"] when force_on_policy_ratio is True which leads to
misleading logs and plots (token_mult_prob_error); change the handling so that
when master_config["loss_fn"].get("force_on_policy_ratio", False) is True you
either (A) avoid emitting prev_logprobs into log_data and skip plotting
token_mult_prob_error, or (B) back-fill train_data["prev_logprobs"] with the
actual on-policy probabilities returned by the training step (e.g., use
train_results["curr_logprobs"] / .detach() if present) before any
logging/visualization; update the code paths around prev_logprobs,
train_results, log_data["prev_logprobs"], and token_mult_prob_error to implement
one of these behaviors.
🧹 Nitpick comments (3)
nemo_rl/algorithms/grpo.py (3)
337-337: Explain why `NRL_IGNORE_TP_ACCURACY_CHECK` is needed when `force_on_policy_ratio` is enabled. Setting a global environment variable as a side effect of a config flag is opaque. Consider adding a comment explaining why the TP accuracy check must be disabled here, so future maintainers understand the coupling.
1602-1621: Minor: `logprob_data` is allocated even when both logprob computations are skipped. When both `skip_prev_logprobs` and `skip_reference_policy_logprobs` are `True`, the `logprob_data` dict on lines 1602–1608 is created but never read. This is lightweight (just references, no tensor copies), so it's not a real concern — just a nit for clarity.
2601-2633: Duplicated skip-logic between `grpo_train` and `async_grpo_train`. Lines 2601–2633 are nearly identical to lines 1579–1621 in `grpo_train`. Consider extracting the flag resolution and conditional `prepare_for_lp_inference` / logprob gating into a shared helper to keep both paths in sync and reduce maintenance burden.
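One possible shape for such a shared helper, sketched under the assumption that both training loops read the same `master_config` keys as the code quoted below (the helper name is invented for illustration):

```python
def resolve_logprob_skip_flags(master_config: dict) -> tuple[bool, bool]:
    """Hypothetical helper: resolve both skip flags in one place so grpo_train
    and async_grpo_train cannot drift apart."""
    force_on_policy_ratio = master_config["loss_fn"].get("force_on_policy_ratio", False)
    skip_prev_logprobs = force_on_policy_ratio
    skip_reference_policy_logprobs = master_config["grpo"].get(
        "skip_reference_policy_logprobs_calculation", False
    )
    return skip_prev_logprobs, skip_reference_policy_logprobs
```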
Quoted diff from `nemo_rl/algorithms/grpo.py` (around lines 1579–1597):

```python
force_on_policy_ratio = master_config["loss_fn"].get(
    "force_on_policy_ratio", False
)
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
    "skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
    print(
        "Skipping prev_logprobs computation due to force_on_policy_ratio=True"
    )
    train_data["prev_logprobs"] = torch.zeros_like(
        train_data["generation_logprobs"]
    )
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
    print("▶ Preparing for logprob inference...", flush=True)
    with timer.time("logprob_inference_prep"):
        policy.prepare_for_lp_inference()
```
Zero-filled prev_logprobs will produce misleading logs and diagnostics.
When force_on_policy_ratio=True, train_data["prev_logprobs"] is filled with zeros. The loss function correctly overrides this internally with curr_logprobs.detach(), but downstream code still reads the raw zeros:
- Line 1886: `log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()` — logs zeros to JSONL.
- Lines 1897–1909: The `token_mult_prob_error` visualization plots `train_data["prev_logprobs"]` (zeros) against `generation_logprobs`, producing a nonsensical chart.
Consider either (a) skipping these log entries when force_on_policy_ratio is on, or (b) back-filling train_data["prev_logprobs"] with the actual curr_logprobs returned from train_results (if available).
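A small sketch of option (b), assuming the training step exposes the on-policy logprobs under a `curr_logprobs` key in `train_results` (that key, and calling this before any logging, are assumptions rather than facts about the PR):

```python
def backfill_prev_logprobs(train_data: dict, train_results: dict, force_on_policy_ratio: bool) -> None:
    # Hypothetical helper: replace the zero-filled prev_logprobs with the actual
    # on-policy logprobs so JSONL logs and token_mult_prob_error plots stay meaningful.
    if not force_on_policy_ratio:
        return
    curr_logprobs = train_results.get("curr_logprobs")  # assumed key, not confirmed by the PR
    if curr_logprobs is not None:
        train_data["prev_logprobs"] = curr_logprobs.detach()
```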
I have two questions related to logprob skipping.
I am still reconciling about 2, but I do agree with 3: when certain conditions are met, we should skip logprob even if the user doesn't explicitly specify it.
No, but I think I can add this to my PR
Yes
What does this PR do ?
Skip logprob and reference logprob computation under certain conditions:
- `loss_fn.skip_reference_policy_logprobs_calculation=true`: skip reference logprob computation. The requirement is `loss_fn.reference_kl_penalty == 0`, which will be checked whenever `skip_reference_policy_logprobs_calculation` is true.
- `loss_fn.force_on_policy_ratio=true`: skip logprob computation. The requirement is rollout batch size == train global batch size, which will be checked whenever `force_on_policy_ratio` is true (see the check sketch below).
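A minimal, illustrative sketch of those two checks; the flag values and batch sizes are passed in explicitly because the real config keys for the batch sizes are not shown in this PR, so everything here is an assumption rather than the PR's actual validation code:

```python
def check_skip_requirements(
    skip_reference_policy_logprobs: bool,
    force_on_policy_ratio: bool,
    reference_kl_penalty: float,
    rollout_batch_size: int,
    train_global_batch_size: int,
) -> None:
    # Illustrative only: enforce the two requirements described above.
    if skip_reference_policy_logprobs:
        assert reference_kl_penalty == 0, (
            "skip_reference_policy_logprobs_calculation requires reference_kl_penalty == 0"
        )
    if force_on_policy_ratio:
        assert rollout_batch_size == train_global_batch_size, (
            "force_on_policy_ratio requires rollout batch size == train global batch size"
        )
```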
Issues

List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
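For example, the flags could be enabled with config overrides along these lines; the nesting mirrors the dictionary accesses in the quoted code (`loss_fn` / `grpo`), so treat the exact paths as assumptions rather than the repo's canonical YAML layout:

```python
# Hypothetical override dict; real NeMo-RL configs may nest these keys differently.
master_config_overrides = {
    "loss_fn": {
        "force_on_policy_ratio": True,   # requires rollout batch size == train global batch size
        "reference_kl_penalty": 0,       # must be 0 when skipping reference logprobs
    },
    "grpo": {
        "skip_reference_policy_logprobs_calculation": True,
    },
}
```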
Before your PR is "Ready for review"

Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes