Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions nemo_rl/models/megatron/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def forward_step_arbitrary_loss(
loss_fn: LossFunction,
pack_sequences: bool = False,
defer_fp32_logits: Optional[bool] = None,
cp_normalize: bool = True,
num_microbatches: int = 1,
policy_cfg: Optional[dict] = None,
):
"""Forward training step with support for packed sequences and context parallelism.
Expand All @@ -65,7 +65,8 @@ def forward_step_arbitrary_loss(
loss_fn (LossFunction): Loss function to apply
pack_sequences (bool): Whether to pack sequences for efficiency
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
cp_normalize (bool): Whether to normalize the loss by the cp_size
num_microbatches (int): Number of microbatches, needed to undo Megatron's default
loss averaging (which multiplies by cp_group_size and divides by num_microbatches).
policy_cfg (Optional[dict]): Policy configuration containing generation parameters

Notes on packed sequences with context parallelism (CP):
Expand Down Expand Up @@ -142,15 +143,19 @@ def forward_step_arbitrary_loss(
context_parallel_group=get_context_parallel_group(),
)

if cp_normalize:
cp_size = get_context_parallel_world_size()
orig_loss_fn_wrapped = loss_fn_wrapped
cp_size = get_context_parallel_world_size()
orig_loss_fn_wrapped = loss_fn_wrapped

def _div_by_cp_size(*args, **kwargs):
loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
return loss / cp_size, metrics
def _undo_megatron_loss_averaging_and_cp_normalize(*args, **kwargs):
loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
# Megatron's default behavior multiplies loss by cp_group_size and divides
# by num_microbatches. We undo that here (* num_microbatches / cp_size)
# so our loss function controls its own normalization.
# We also divide by cp_size for CP normalization.
loss = loss * num_microbatches / cp_size / cp_size
return loss, metrics

loss_fn_wrapped = _div_by_cp_size
loss_fn_wrapped = _undo_megatron_loss_averaging_and_cp_normalize

return output_tensor, loss_fn_wrapped

Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def train(
global_valid_toks,
pack_sequences=self.cfg["sequence_packing"]["enabled"],
defer_fp32_logits=self.defer_fp32_logits,
num_microbatches=num_microbatches,
),
data_iterator=data_iterator,
model=self.model,
Expand All @@ -400,7 +401,6 @@ def train(
micro_batch_size=mbs,
decoder_seq_length=padded_seq_length,
forward_only=eval_mode,
do_not_average_loss=True,
)

# Empty unused memory.
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/algorithms/test_sequence_packing_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
model=MockModel(),
loss_fn=base_loss_fn,
pack_sequences=True,
cp_normalize=True,
num_microbatches=1,
)
loss, metrics = wrapped_loss_fn(output_tensor)

Expand Down
Loading