From f5eb204aa1b33ccf4f47ebbd08f12527992e2a73 Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Thu, 12 Feb 2026 17:29:42 -0800 Subject: [PATCH] Remove do_not_average_loss; undo Megatron loss averaging in RL code Instead of passing do_not_average_loss=True to Megatron's forward_backward_func, we now let Megatron apply its default loss averaging (output_tensor *= cp_group_size; output_tensor /= num_microbatches) and undo it in forward_step_arbitrary_loss by applying the inverse (loss *= num_microbatches / cp_size), while retaining the original CP normalization (an additional loss /= cp_size), so the net scaling seen by Megatron matches the previous do_not_average_loss behavior. Changes: - common.py: replace the cp_normalize param with a num_microbatches param, replace _div_by_cp_size with the _undo_megatron_loss_averaging_and_cp_normalize wrapper - megatron_policy_worker.py: remove do_not_average_loss=True, pass num_microbatches to forward_step partial - test_sequence_packing_gradients.py: update call to match new signature --- nemo_rl/models/megatron/common.py | 23 +++++++++++-------- .../policy/workers/megatron_policy_worker.py | 2 +- .../test_sequence_packing_gradients.py | 2 +- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 28af36b11b..e4da4c709b 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -51,7 +51,7 @@ def forward_step_arbitrary_loss( loss_fn: LossFunction, pack_sequences: bool = False, defer_fp32_logits: Optional[bool] = None, - cp_normalize: bool = True, + num_microbatches: int = 1, policy_cfg: Optional[dict] = None, ): """Forward training step with support for packed sequences and context parallelism. 
@@ -65,7 +65,8 @@ def forward_step_arbitrary_loss( loss_fn (LossFunction): Loss function to apply pack_sequences (bool): Whether to pack sequences for efficiency defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 - cp_normalize (bool): Whether to normalize the loss by the cp_size + num_microbatches (int): Number of microbatches, needed to undo Megatron's default + loss averaging (which multiplies by cp_group_size and divides by num_microbatches). policy_cfg (Optional[dict]): Policy configuration containing generation parameters Notes on packed sequences with context parallelism (CP): @@ -142,15 +143,19 @@ def forward_step_arbitrary_loss( context_parallel_group=get_context_parallel_group(), ) - if cp_normalize: - cp_size = get_context_parallel_world_size() - orig_loss_fn_wrapped = loss_fn_wrapped + cp_size = get_context_parallel_world_size() + orig_loss_fn_wrapped = loss_fn_wrapped - def _div_by_cp_size(*args, **kwargs): - loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) - return loss / cp_size, metrics + def _undo_megatron_loss_averaging_and_cp_normalize(*args, **kwargs): + loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) + # Megatron's default behavior multiplies loss by cp_group_size and divides + # by num_microbatches. We undo that here (* num_microbatches / cp_size) + # so our loss function controls its own normalization. + # We also divide by cp_size for CP normalization. 
+ loss = loss * num_microbatches / cp_size / cp_size + return loss, metrics - loss_fn_wrapped = _div_by_cp_size + loss_fn_wrapped = _undo_megatron_loss_averaging_and_cp_normalize return output_tensor, loss_fn_wrapped diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 798c4ea00a..9e5d789d19 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -392,6 +392,7 @@ def train( global_valid_toks, pack_sequences=self.cfg["sequence_packing"]["enabled"], defer_fp32_logits=self.defer_fp32_logits, + num_microbatches=num_microbatches, ), data_iterator=data_iterator, model=self.model, @@ -400,7 +401,6 @@ def train( micro_batch_size=mbs, decoder_seq_length=padded_seq_length, forward_only=eval_mode, - do_not_average_loss=True, ) # Empty unused memory. diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index 2982d1b0a8..b2d2852769 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -351,7 +351,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): model=MockModel(), loss_fn=base_loss_fn, pack_sequences=True, - cp_normalize=True, + num_microbatches=1, ) loss, metrics = wrapped_loss_fn(output_tensor)