
Remove do_not_average_loss; undo Megatron loss averaging in RL code#1940

Open
yaoyu-33 wants to merge 1 commit into main from yuya/undo-megatron-loss-averaging

Conversation

Contributor

@yaoyu-33 yaoyu-33 commented Feb 13, 2026

Summary

Instead of passing do_not_average_loss=True to Megatron's forward_backward_func, we now let Megatron apply its default loss averaging (output_tensor *= cp_group_size; output_tensor /= num_microbatches) and undo it in forward_step_arbitrary_loss by applying the inverse (loss *= num_microbatches / cp_size).

This removes our dependency on the upstream do_not_average_loss option in Megatron-LM (ref: PR 2951).
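
As a quick, purely illustrative sanity check (no Megatron involved; the helper names below are invented for this snippet and are not part of nemo_rl or Megatron-LM), the following self-contained PyTorch code verifies that pre-scaling a loss by num_microbatches / cp_size is exactly cancelled by Megatron's default averaging:

```python
# Illustrative check only: simulate Megatron's default loss averaging
# (loss *= cp_size; loss /= num_microbatches) and confirm that pre-scaling
# by num_microbatches / cp_size cancels it.
import torch

def megatron_default_averaging(loss, cp_size, num_microbatches):
    # What the scheduler applies to each microbatch loss by default.
    return loss * cp_size / num_microbatches

def undo_factor(num_microbatches, cp_size):
    # The inverse factor applied inside the forward step before averaging.
    return num_microbatches / cp_size

raw_loss = torch.tensor(2.5)
for cp_size in (1, 2, 4):
    for num_microbatches in (1, 8):
        pre_scaled = raw_loss * undo_factor(num_microbatches, cp_size)
        recovered = megatron_default_averaging(pre_scaled, cp_size, num_microbatches)
        torch.testing.assert_close(recovered, raw_loss)
print("num_microbatches / cp_size cancels Megatron's default averaging")
```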

Changes

  • nemo_rl/models/megatron/common.py: Rename cp_normalize -> undo_megatron_loss_averaging, add a num_microbatches param, and replace the _div_by_cp_size wrapper with _undo_megatron_loss_averaging, which applies loss * num_microbatches / cp_size
  • nemo_rl/models/policy/workers/megatron_policy_worker.py: Remove do_not_average_loss=True from the forward_backward_func call and pass num_microbatches to the forward_step partial (a binding sketch follows this list)
  • tests/unit/algorithms/test_sequence_packing_gradients.py: Update the call to forward_step_arbitrary_loss to use the new parameter names
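
A minimal, hypothetical sketch of the caller-side pattern (the stub below stands in for forward_step_arbitrary_loss; only the binding idea is taken from this PR description, everything else is illustrative):

```python
# Hypothetical sketch: bind num_microbatches into the forward step with
# functools.partial, replacing the old do_not_average_loss=True flag.
from functools import partial

def forward_step_stub(data_iterator, model, *, num_microbatches: int):
    # The real forward step would run the model and return an output tensor
    # plus a loss wrapper; here we only show the pre-bound argument.
    return f"forward step sees num_microbatches={num_microbatches}"

num_microbatches = 8
forward_step_func = partial(forward_step_stub, num_microbatches=num_microbatches)

# Megatron's forward_backward_func would call forward_step_func(data_iterator, model)
# once per microbatch; note that do_not_average_loss is no longer passed anywhere.
print(forward_step_func(data_iterator=None, model=None))
```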

Summary by CodeRabbit

  • Refactor
    • Changed loss computation parameter from a boolean normalization flag to an explicit microbatch count, improving clarity in training configuration and microbatch handling during loss computation.

@yaoyu-33 yaoyu-33 requested review from a team as code owners February 13, 2026 01:30
Contributor

coderabbitai bot commented Feb 13, 2026

📝 Walkthrough

The pull request refactors loss normalization handling in Megatron-based RL training by replacing a boolean cp_normalize parameter with an explicit num_microbatches integer parameter. The loss normalization wrapper is updated to deterministically apply loss scaling based on microbatch count and context-parallel size, with corresponding changes to caller sites.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Loss Normalization API Update: nemo_rl/models/megatron/common.py | Replaced the cp_normalize bool parameter with num_microbatches int (default 1) in forward_step_arbitrary_loss. Introduced the _undo_megatron_loss_averaging_and_cp_normalize wrapper, which unconditionally applies loss scaling: loss * num_microbatches / cp_size / cp_size. Removed the conditional normalization logic. |
| Worker Integration: nemo_rl/models/policy/workers/megatron_policy_worker.py | Updated the forward_backward_func call to forward the num_microbatches argument to forward_step_func. Removed the do_not_average_loss=True parameter. |
| Test Updates: tests/unit/algorithms/test_sequence_packing_gradients.py | Changed the loss wrapper call parameter from cp_normalize=True to num_microbatches=1 to align with the updated API. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Test Results For Major Changes | ⚠️ Warning | The PR contains major loss normalization changes affecting numerics/convergence but includes no test results, regression testing information, or verification demonstrating no regression. | Add test results and convergence validation data to the PR description demonstrating that training behavior is unchanged and numerical correctness is maintained. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title directly reflects the main changes: removing do_not_average_loss and implementing logic to undo Megatron's loss averaging in the RL codebase. |
| Merge Conflict Detection | ✅ Passed | No merge conflicts detected when merging into main. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
nemo_rl/models/megatron/common.py (2)

146-158: Loss correction logic is mathematically sound.

The two divisions by cp_size are correct: one undoes Megatron's *= cp_size, the other applies CP normalization, matching the prior behavior when do_not_average_loss=True + _div_by_cp_size was used.

One minor note: this wrapper assumes it runs inside forward_backward_func (which applies *= cp_size and /= num_microbatches). If called standalone (as in test 3 of test_sequence_packing_gradients.py), the undo is applied to averaging that never happened. This is benign when cp_size == 1 and num_microbatches == 1, but a brief docstring note that this function is designed to be used as forward_step_func in forward_backward_func would help prevent misuse.
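
A hedged sketch of what the wrapper plus the suggested docstring note could look like (structure, names, and the (loss, metrics) return convention are inferred from this review, not copied from the actual nemo_rl source):

```python
# Sketch only: scales by num_microbatches / cp_size / cp_size and carries the
# docstring note suggested above.
import torch

def _undo_megatron_loss_averaging_and_cp_normalize(loss_fn, num_microbatches: int, cp_size: int):
    def wrapped(output_tensor: torch.Tensor):
        """Intended to run as part of forward_step_func inside Megatron's
        forward_backward_func, which later applies
        loss *= cp_size; loss /= num_microbatches.
        Called standalone, it "undoes" averaging that never happened, which is
        benign only when cp_size == 1 and num_microbatches == 1.
        """
        loss, metrics = loss_fn(output_tensor)
        # * num_microbatches cancels Megatron's /= num_microbatches,
        # one / cp_size cancels its *= cp_size, and the second / cp_size
        # applies the CP normalization that _div_by_cp_size used to perform.
        return loss * num_microbatches / cp_size / cp_size, metrics
    return wrapped
```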


54-54: Default num_microbatches=1 could silently produce wrong loss scaling.

If a caller forgets to pass the actual microbatch count, the default of 1 will silently produce an incorrectly-scaled loss when the real microbatch count is > 1. Consider removing the default so callers are forced to be explicit, or at least documenting the risk.

Proposed fix: remove the default value:

```diff
-    num_microbatches: int = 1,
+    num_microbatches: int,
```
tests/unit/algorithms/test_sequence_packing_gradients.py (1)

354-354: API call updated correctly, but test 3 lacks assertions for cp_size > 1.

The parameter change from cp_normalize=True to num_microbatches=1 matches the new API. However, since forward_step_arbitrary_loss is called directly here (not through forward_backward_func), the loss wrapper undoes Megatron averaging that was never applied. This is a no-op for cp_size=1 but scales incorrectly for cp_size=2. Test 3 only prints gradient stats without any assert_close, so this doesn't cause failures — but it means test 3 doesn't actually validate correctness for cp_size=2.

Consider adding an assertion or at least a comment noting that test 3 is intentionally diagnostic-only.
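
For instance, a hypothetical way to make the diagnostic intent explicit in test 3 (placeholder names; not the actual test code):

```python
# Hypothetical addition for test 3; `grads` is a stand-in for the gradients
# the test already collects, not a name taken from the real test file.
import torch

grads = [torch.randn(4, 4), torch.randn(8)]  # placeholder gradient tensors

# Test 3 is intentionally diagnostic-only: forward_step_arbitrary_loss is called
# outside forward_backward_func, so the "undo" scaling is a no-op only when
# cp_size == 1; gradients are checked for finiteness rather than exact values.
assert all(torch.isfinite(g).all() for g in grads), "gradients contain NaN/Inf"
```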


