From bb7ae5e3a04d014dbc17d11e0730878b0cce20c1 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sat, 7 Feb 2026 00:10:49 -0800 Subject: [PATCH 1/4] update ClippedPGLossFn, NLLLoss, DPOLossFn Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 135 ++------------------------- nemo_rl/distributed/model_utils.py | 44 +++++++++ nemo_rl/models/automodel/train.py | 11 ++- 3 files changed, 60 insertions(+), 130 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index c61cb5f0ce..7bcfb8cf1f 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -25,9 +25,7 @@ ChunkedDistributedGatherLogprob, _get_tokens_on_this_cp_rank, allgather_cp_sharded_tensor, - from_parallel_logits_to_logprobs, gather_logits_at_global_indices, - get_logprobs_from_vocab_parallel_logits, ) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -185,13 +183,10 @@ def __init__(self, cfg: ClippedPGLossConfig): def __call__( self, - next_token_logits: Tensor, + curr_logprobs: Tensor, data: BatchedDataDict[ClippedPGLossDataDict], global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" token_mask = data["token_mask"][:, 1:] @@ -201,7 +196,6 @@ def __call__( generation_logprobs = data["generation_logprobs"][:, 1:] if self.reference_policy_kl_penalty != 0: reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] - seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -269,39 +263,6 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - next_token_logits = next_token_logits.to(torch.float32) - - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - curr_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_token_logits_wo_last = next_token_logits[ - :, :-1 - ] # Remove last position's logits - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits_wo_last, dim=-1 - ) - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - curr_logprobs = next_token_logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: @@ -539,13 +500,10 @@ class NLLLoss(LossFunction): def __call__( self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, ) -> tuple[torch.Tensor, dict[str, Any]]: @@ -554,39 +512,6 @@ def __call__( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - - # Gather the logprobs for the actual next tokens - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) if dpo_loss: ## shape: [batch_size] @@ -800,50 +725,15 @@ def __init__(self, cfg: DPOLossConfig): def _dpo_loss( self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) ref_logprobs = data["reference_policy_logprobs"][:, :-1] - diff = (token_logprobs - ref_logprobs) * token_mask rewards = diff.sum(-1) @@ -857,13 +747,10 @@ def _dpo_loss( # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) def __call__( # type: ignore self, - next_token_logits: Tensor, + token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -871,13 +758,10 @@ def __call__( # type: ignore "global_valid_toks must be provided for SFT loss" ) sft_loss, _ = self.sft_loss( - next_token_logits, + token_logprobs, data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) @@ -893,14 +777,7 @@ def __call__( # type: ignore accuracy, rewards_chosen_mean, rewards_rejected_mean, - ) = self._dpo_loss( - next_token_logits, - data, - global_valid_seqs, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, - ) + ) = self._dpo_loss(token_logprobs, data, global_valid_seqs) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index fb17ee1661..b012777279 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -825,6 +825,50 @@ def get_logprobs_from_vocab_parallel_logits( ) +def get_logprobs_from_logits( + input_ids: torch.Tensor, + next_token_logits: torch.Tensor, + seq_index: Optional[torch.Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Computes log probabilities from logits.""" + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, input_ids, seq_index=seq_index + ) + else: + # Remove last position's logits + next_token_logits_wo_last = next_token_logits[:, :-1] + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = input_ids[:, 1:].cuda() # Skip first token + logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + return logprobs + + @torch.no_grad() def distributed_vocab_topk( vocab_parallel_logits: torch.Tensor, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index acbfec711e..6c7892e0ef 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -37,6 +37,7 @@ from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, + get_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch @@ -513,6 +514,14 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) + # Compute logprobs from logits + logprobs = get_logprobs_from_logits( + input_ids=mb["input_ids"], + next_token_logits=logits, + seq_index=mb.get("seq_index", None), + ) + del logits + # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( @@ -524,7 +533,7 @@ def __call__( loss_fn_ = self.loss_fn loss, loss_metrics = loss_fn_( - logits, + logprobs, mb, global_valid_seqs, global_valid_toks, From c48508e20a6f332c9b3f8c24632aead5f755525a Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 06:14:39 -0800 Subject: [PATCH 2/4] fix seq packing Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 25 ++++++++------ nemo_rl/models/automodel/train.py | 51 ++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 7bcfb8cf1f..9498d1a918 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Any, NotRequired, Optional, TypedDict, TypeVar +from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar import torch import torch.distributed @@ -802,12 +802,20 @@ class SequencePackingLossWrapper: def __init__( self, loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], cu_seqlens_q: Tensor, cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ): self.loss_fn = loss_fn + self.prepare_fn = prepare_fn self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_q_padded = cu_seqlens_q_padded + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group def __call__( self, @@ -815,9 +823,6 @@ def __call__( data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, dict[str, Any]]: """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" unpadded_cu_seqlens = self.cu_seqlens_q @@ -851,8 +856,8 @@ def __call__( # get next_token_logits cp_size = ( 1 - if context_parallel_group is None - else torch.distributed.get_world_size(context_parallel_group) + if self.context_parallel_group is None + else torch.distributed.get_world_size(self.context_parallel_group) ) logit_start = seq_start // cp_size logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size @@ -861,14 +866,14 @@ def __call__( 1, logit_start, logit_length ) + # prepare data for loss function + loss_fn_args = self.prepare_fn(next_token_logits_slice, unpadded_seq_data) + loss, metrics = self.loss_fn( - next_token_logits_slice, + *loss_fn_args, unpadded_seq_data, global_valid_seqs, global_valid_toks, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, ) loss_accum += loss for k, v in metrics.items(): diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index 6c7892e0ef..cb8cc1c939 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -505,6 +505,12 @@ def __call__( Returns: Tuple of (loss, metrics) """ + from nemo_rl.algorithms.loss_functions import ( + ClippedPGLossFn, + DPOLossFn, + NLLLoss, + ) + # Handle CP redistribution if self.cp_size > 1: _, mb = prepare_data_for_cp( @@ -514,30 +520,45 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) - # Compute logprobs from logits - logprobs = get_logprobs_from_logits( - input_ids=mb["input_ids"], - next_token_logits=logits, - seq_index=mb.get("seq_index", None), - ) - del logits + # Prepare data for loss function + def prepare_for_loss_fn( + logits: torch.Tensor, mb: BatchedDataDict[Any] + ) -> tuple[Any]: + if isinstance(self.loss_fn, (ClippedPGLossFn, NLLLoss, DPOLossFn)): + logprobs = get_logprobs_from_logits( + input_ids=mb["input_ids"], + next_token_logits=logits, + seq_index=mb.get("seq_index", None), + ) + + loss_fn_args = (logprobs,) + + # TODO: PreferenceLoss, DistillationLossFn + + return loss_fn_args # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=self.loss_fn, + prepare_fn=prepare_for_loss_fn, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) else: - loss_fn_ = self.loss_fn - - loss, loss_metrics = loss_fn_( - logprobs, - mb, - global_valid_seqs, - global_valid_toks, - ) + loss_fn_args = prepare_for_loss_fn(logits, mb) + loss, loss_metrics = self.loss_fn( + *loss_fn_args, + mb, + global_valid_seqs, + global_valid_toks, + ) return loss, loss_metrics From 4c7c2ea8da478fde66cfacd7f0f82fde5ca11ba6 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 22:04:30 -0800 Subject: [PATCH 3/4] update unit test for sft/rl/dpo and add value check for distillation Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_loss_functions.py | 245 +++++++------------ 1 file changed, 92 insertions(+), 153 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 5be0e69c80..951d2cc226 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -26,6 +26,7 @@ ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import get_logprobs_from_logits basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -91,8 +92,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -116,8 +118,9 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -151,8 +154,9 @@ def test_dpo_loss(): } ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) loss, metrics_dict = loss_fn( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -185,7 +189,7 @@ def test_dpo_loss(): expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) assert torch.isclose( loss_fn_with_sft( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -260,16 +264,17 @@ def test_dpo_loss_varying_sequence_lengths(): "sample_mask": sample_mask, } ) + token_logprobs = get_logprobs_from_logits(data["input_ids"], next_token_logits) # Compute loss loss, metrics = dpo_loss_fn_no_avg( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), ) loss_avg, metrics_avg = dpo_loss_fn_avg( - next_token_logits, + token_logprobs, data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), @@ -322,8 +327,11 @@ def test_dpo_sft_matches_nll_loss(): # Compute NLL loss nll_loss_fn = NLLLoss() + token_logprobs = get_logprobs_from_logits( + sft_data["input_ids"], next_token_logits[::2] + ) nll_loss, nll_metrics = nll_loss_fn( - next_token_logits[::2], + token_logprobs, sft_data, global_valid_seqs=None, global_valid_toks=torch.sum( @@ -341,8 +349,9 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) + token_logprobs = get_logprobs_from_logits(dpo_data["input_ids"], next_token_logits) dpo_loss, dpo_metrics = dpo_loss_fn( - next_token_logits, + token_logprobs, dpo_data, global_valid_seqs=torch.sum(dpo_data["sample_mask"]), global_valid_toks=torch.sum( @@ -504,9 +513,10 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -551,9 +561,10 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -596,9 +607,10 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, metrics = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -706,9 +718,10 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -734,6 +747,8 @@ def test_clipped_pg_loss_masking(): ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 data["reference_policy_logprobs"] = ( @@ -749,7 +764,7 @@ def test_clipped_pg_loss_masking(): # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample loss_default, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -765,7 +780,7 @@ def test_clipped_pg_loss_masking(): ) loss_token_masked, _ = loss_fn( - dummy_logits, + current_logprobs, data_mod_token, global_valid_seqs=torch.sum(data_mod_token["sample_mask"]), global_valid_toks=torch.sum( @@ -784,7 +799,7 @@ def test_clipped_pg_loss_masking(): ) # Ignore item 1 loss_sample_masked, _ = loss_fn( - dummy_logits, + current_logprobs, data_mod_sample, global_valid_seqs=torch.sum(data_mod_sample["sample_mask"]), global_valid_toks=torch.sum( @@ -805,8 +820,11 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] + current_logprobs_only_b0 = get_logprobs_from_logits( + data_only_b0["input_ids"], logits_only_b0 + ) loss_only_b0, _ = loss_fn( - logits_only_b0, + current_logprobs_only_b0, data_only_b0, global_valid_seqs=torch.sum(data_only_b0["sample_mask"]), global_valid_toks=torch.sum( @@ -826,6 +844,7 @@ def test_clipped_pg_loss_zero_mask(): data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 @@ -835,7 +854,7 @@ def test_clipped_pg_loss_zero_mask(): data["token_mask"] = torch.zeros_like(data["token_mask"]) loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -980,9 +999,10 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1112,9 +1132,10 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1233,9 +1254,10 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1282,8 +1304,10 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(data["input_ids"], dummy_logits) + _, metrics = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1365,9 +1389,10 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1463,9 +1488,10 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1564,9 +1590,10 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + current_logprobs = get_logprobs_from_logits(input_ids, dummy_logits) actual_loss, _ = loss_fn( - dummy_logits, + current_logprobs, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), @@ -1581,6 +1608,10 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) device = "cuda" + # Set seed for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + # Create input data input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) input_lengths = torch.tensor([seq_len] * batch_size, device=device) @@ -1608,15 +1639,17 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) return data, student_logits -def test_distillation_loss_forward_kl(): - """Test forward KL divergence loss calculation.""" +@pytest.mark.parametrize("kl_type", ["forward", "reverse", "mixed"]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_different_settings(kl_type, zero_outside_topk): + """Test different distillation loss settings.""" data, student_logits = setup_distillation_test_data() loss_fn = DistillationLossFn( { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, + "kl_type": kl_type, + "mixed_kl_weight": 0.3, + "zero_outside_topk": zero_outside_topk, } ) @@ -1629,25 +1662,38 @@ def test_distillation_loss_forward_kl(): ), ) - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) + # Verify loss + if zero_outside_topk: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(-0.9636520743370056)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(-490.5150451660156)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(-343.6496276855469)) + else: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(0.5783048868179321)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(0.5811167359352112)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(0.5802732110023499)) # Verify metrics dictionary assert isinstance(metrics, dict) assert "loss" in metrics -def test_distillation_loss_reverse_kl(): - """Test reverse KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() +@pytest.mark.parametrize("k", [1, 32, 64, 1000000]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_topk_filtering(k, zero_outside_topk): + """Test top-k filtering functionality with various k values.""" + data, student_logits = setup_distillation_test_data(topk=k) loss_fn = DistillationLossFn( { - "kl_type": "reverse", + "kl_type": "forward", "mixed_kl_weight": 0.5, - "zero_outside_topk": False, + "zero_outside_topk": zero_outside_topk, } ) @@ -1660,86 +1706,19 @@ def test_distillation_loss_reverse_kl(): ), ) - # Verify loss is a scalar tensor + # Verify loss is calculated correctly with top-k filtering assert loss.dim() == 0 assert not torch.isnan(loss) assert not torch.isinf(loss) - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics - + # For k=1, we expect only the top-1 token to be considered + if k == 1: + assert isinstance(loss, torch.Tensor) -def test_distillation_loss_mixed_kl(): - """Test mixed KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() - - mixed_kl_weight = 0.3 - loss_fn = DistillationLossFn( - { - "kl_type": "mixed", - "mixed_kl_weight": mixed_kl_weight, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics - - -def test_distillation_loss_topk_filtering(): - """Test top-k filtering functionality with various k values.""" - # Test with different k values (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with top-k filtering - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, we expect only the top-1 token to be considered - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, we expect normal behavior - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss + # For large k values, we expect normal behavior + if k >= 32: + assert isinstance(loss, torch.Tensor) + assert loss.item() != 0.0 # Should have some meaningful loss def test_distillation_loss_invalid_k_zero(): @@ -1767,46 +1746,6 @@ def test_distillation_loss_invalid_k_zero(): ) -def test_distillation_loss_zero_outside_topk(): - """Test zeroing outside top-k functionality with various k values.""" - # Test with different k values for zero_outside_topk (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": True, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with zeroing - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, only top-1 token should remain non-zero - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, most tokens should remain non-zero - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss - - def test_distillation_loss_gradient_flow(): """Test gradient flow in distillation loss function.""" data, student_logits = setup_distillation_test_data() From 54e12834ea4612f9e94f905dc096f8dea6759043 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 9 Feb 2026 23:48:39 -0800 Subject: [PATCH 4/4] update PreferenceLoss and DistillationLossFn Signed-off-by: Yuki Huang --- nemo_rl/algorithms/loss_functions.py | 166 +----------------- nemo_rl/distributed/model_utils.py | 167 +++++++++++++++++++ nemo_rl/models/automodel/train.py | 25 ++- tests/unit/algorithms/test_loss_functions.py | 96 +++++++++-- 4 files changed, 275 insertions(+), 179 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 9498d1a918..f615df1e01 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -20,13 +20,6 @@ from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - ChunkedDistributedEntropy, - ChunkedDistributedGatherLogprob, - _get_tokens_on_this_cp_rank, - allgather_cp_sharded_tensor, - gather_logits_at_global_indices, -) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -932,165 +925,14 @@ def __init__(self, cfg: DistillationLossConfig): def __call__( self, - next_token_logits: torch.Tensor, + student_topk_logprobs: torch.Tensor, + teacher_topk_logprobs: torch.Tensor, + H_all: torch.Tensor | None, data: DistillationLossDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute distillation loss between teacher and student logits.""" - # Basic shapes - input_ids = data["input_ids"] - batch_size = input_ids.shape[0] - - # CP support: get CP group and size - cp_group = context_parallel_group - cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) - - # Ensure float32 for stability (match other losses) - next_token_logits = next_token_logits.to(torch.float32) - per_token_kl = None - # Preferred truncated-KL path: teacher provides top-k support per position - teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k] - teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k] - - if teacher_topk_indices.shape[-1] <= 0: - raise ValueError( - f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " - "topk=0 is not supported as it would result in empty tensor operations." - ) - - # Determine processing path and setup variables - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - V_local = int(next_token_logits.shape[-1]) - vocab_start_index = vocab_parallel_rank * V_local - vocab_end_index = (vocab_parallel_rank + 1) * V_local - parallel_group = vocab_parallel_group - logits_tensor = next_token_logits - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - device_mesh = next_token_logits.device_mesh - tp_group = device_mesh.get_group("tp") - tp_rank = tp_group.rank() - local_student_logits = next_token_logits.to_local() - V_local = int(local_student_logits.shape[-1]) - vocab_start_index = tp_rank * V_local - vocab_end_index = (tp_rank + 1) * V_local - parallel_group = tp_group - logits_tensor = local_student_logits - teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device) - # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment - if ( - device_mesh.mesh_dim_names is not None - and "cp" in device_mesh.mesh_dim_names - ): - cp_group = device_mesh.get_group("cp") - cp_size = cp_group.size() - else: - cp_group = None - cp_size = 1 - else: - parallel_group = None - logits_tensor = next_token_logits - - # Process based on zero_outside_topk setting - if self.zero_outside_topk and parallel_group is not None: - # Distributed processing with chunking - indices_local = teacher_topk_indices - pad_len = 0 - if cp_size > 1: - pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1] - if pad_len > 0: - indices_local = torch.nn.functional.pad( - indices_local, (0, 0, 0, pad_len), value=0 - ) - cp_rank = torch.distributed.get_rank(cp_group) - indices_local = _get_tokens_on_this_cp_rank( - indices_local, cp_rank, cp_size, seq_dim=1 - ) - - S_local = int(logits_tensor.shape[1]) - chunk_size = max(1, min(S_local, 1024)) - student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore - logits_tensor, - indices_local, - vocab_start_index, - vocab_end_index, - chunk_size, - parallel_group, - False, - ) - - if self.kl_type != "forward": - H_all = ChunkedDistributedEntropy.apply( # type: ignore - logits_tensor, - chunk_size, - parallel_group, - False, - ) - - if cp_size > 1: - student_topk_logprobs = allgather_cp_sharded_tensor( - student_topk_logprobs, cp_group, seq_dim=1 - ) - if self.kl_type != "forward": - H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) - if pad_len > 0: - student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] - if self.kl_type != "forward": - H_all = H_all[:, :-pad_len] - elif self.zero_outside_topk: - # Non-distributed processing - student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1) - student_topk_logprobs = student_logprobs.gather( - dim=-1, index=teacher_topk_indices.to(student_logprobs.device) - ) - if self.kl_type != "forward": - H_all = (student_logprobs.exp() * student_logprobs).sum(-1) - else: - # Gather logits at global indices - if (parallel_group is not None) or (cp_size > 1): - student_topk_logits = gather_logits_at_global_indices( - logits_tensor, - teacher_topk_indices, - tp_group=parallel_group, - cp_group=cp_group, - vocab_start_index=( - vocab_start_index if parallel_group is not None else 0 - ), - vocab_end_index=( - vocab_end_index - if parallel_group is not None - else int(logits_tensor.shape[-1]) - ), - ) - else: - student_topk_logits = logits_tensor.gather( - dim=-1, index=teacher_topk_indices.to(logits_tensor.device) - ) - student_topk_logprobs = torch.nn.functional.log_softmax( - student_topk_logits, dim=-1 - ) - - # Move teacher tensors to the same device/dtype as student_topk_logits - teacher_topk_logits = teacher_topk_logits.to( - student_topk_logprobs.device, dtype=student_topk_logprobs.dtype - ) - teacher_topk_logprobs = torch.nn.functional.log_softmax( - teacher_topk_logits, dim=-1 - ) - - # Single point of next-token alignment after TP/CP processing - teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] - student_topk_logprobs = student_topk_logprobs[:, :-1, :] - if self.zero_outside_topk and self.kl_type != "forward": - # Align H_all with next-token prediction - H_all = H_all[:, :-1] - student_probs = student_topk_logprobs.exp() # [B, S-1, k] teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k] @@ -1143,7 +985,7 @@ def __call__( metrics = { "loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, - "num_valid_samples": int(batch_size), + "num_valid_samples": data["input_ids"].shape[0], } return kl_loss, metrics diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index b012777279..50ffb1a28d 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -1026,6 +1026,173 @@ def gather_logits_at_global_indices( return gathered_logits +def get_distilllation_topk_logprobs_from_logits( + student_logits: torch.Tensor, + teacher_topk_logits: torch.Tensor, + teacher_topk_indices: torch.Tensor, + zero_outside_topk: bool, + calculate_entropy: bool, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Compute top-k log probabilities from logits.""" + if teacher_topk_indices.shape[-1] <= 0: + raise ValueError( + f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " + "topk=0 is not supported as it would result in empty tensor operations." + ) + + # Ensure float32 for stability + student_logits = student_logits.to(torch.float32) + # Move teacher topk indices to the same device as student logits + teacher_topk_indices = teacher_topk_indices.to(student_logits.device) + + # CP support: get CP group and size + cp_group = context_parallel_group + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + + # Process based on the student logits type + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + student_logits = student_logits + parallel_group = vocab_parallel_group + + V_local = int(student_logits.shape[-1]) + vocab_start_index = vocab_parallel_rank * V_local + vocab_end_index = (vocab_parallel_rank + 1) * V_local + + elif isinstance(student_logits, torch.distributed.tensor.DTensor): + device_mesh = student_logits.device_mesh + tp_group = device_mesh.get_group("tp") + + student_logits = student_logits.to_local() + parallel_group = tp_group + + tp_rank = tp_group.rank() + V_local = int(student_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment + if ( + device_mesh.mesh_dim_names is not None + and "cp" in device_mesh.mesh_dim_names + ): + cp_group = device_mesh.get_group("cp") + cp_size = cp_group.size() + else: + cp_group = None + cp_size = 1 + + else: + student_logits = student_logits + parallel_group = None + + # Process based on the zero_outside_topk setting + H_all = None + if zero_outside_topk: + # Distributed processing + if parallel_group is not None: + indices_local = teacher_topk_indices + pad_len = 0 + + if cp_size > 1: + pad_len = student_logits.shape[1] * cp_size - indices_local.shape[1] + if pad_len > 0: + indices_local = torch.nn.functional.pad( + indices_local, (0, 0, 0, pad_len), value=0 + ) + cp_rank = torch.distributed.get_rank(cp_group) + indices_local = _get_tokens_on_this_cp_rank( + indices_local, cp_rank, cp_size, seq_dim=1 + ) + + seq_len_local = int(student_logits.shape[1]) + chunk_size = max(1, min(seq_len_local, 1024)) + student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore + student_logits, + indices_local, + vocab_start_index, + vocab_end_index, + chunk_size, + parallel_group, + False, + ) + + if calculate_entropy: + H_all = ChunkedDistributedEntropy.apply( # type: ignore + student_logits, + chunk_size, + parallel_group, + False, + ) + + if cp_size > 1: + student_topk_logprobs = allgather_cp_sharded_tensor( + student_topk_logprobs, cp_group, seq_dim=1 + ) + if calculate_entropy: + H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) + if pad_len > 0: + student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] + if calculate_entropy: + H_all = H_all[:, :-pad_len] + + # Non-distributed processing + else: + student_logprobs = torch.nn.functional.log_softmax(student_logits, dim=-1) + student_topk_logprobs = student_logprobs.gather( + dim=-1, index=teacher_topk_indices + ) + + if calculate_entropy: + H_all = (student_logprobs.exp() * student_logprobs).sum(-1) + + else: + # Distributed processing + if parallel_group is not None or cp_size > 1: + if parallel_group is None: + vocab_start_index = 0 + vocab_end_index = int(student_logits.shape[-1]) + + student_topk_logits = gather_logits_at_global_indices( + student_logits, + teacher_topk_indices, + tp_group=parallel_group, + cp_group=cp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + + # Non-distributed processing + else: + student_topk_logits = student_logits.gather( + dim=-1, index=teacher_topk_indices + ) + + student_topk_logprobs = torch.nn.functional.log_softmax( + student_topk_logits, dim=-1 + ) + + # Move teacher tensors to the same device/dtype as student_topk_logits + teacher_topk_logits = teacher_topk_logits.to( + student_topk_logprobs.device, dtype=student_topk_logprobs.dtype + ) + teacher_topk_logprobs = torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1) + + # Single point of next-token alignment after TP/CP processing + teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] + student_topk_logprobs = student_topk_logprobs[:, :-1, :] + + if calculate_entropy: + H_all = H_all[:, :-1] + + return student_topk_logprobs, teacher_topk_logprobs, H_all + + class ChunkedDistributedEntropy(torch.autograd.Function): """Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index cb8cc1c939..d7a311bbf1 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -37,6 +37,7 @@ from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, distributed_vocab_topk, + get_distilllation_topk_logprobs_from_logits, get_logprobs_from_logits, get_logprobs_from_vocab_parallel_logits, ) @@ -507,8 +508,10 @@ def __call__( """ from nemo_rl.algorithms.loss_functions import ( ClippedPGLossFn, + DistillationLossFn, DPOLossFn, NLLLoss, + PreferenceLoss, ) # Handle CP redistribution @@ -533,7 +536,27 @@ def prepare_for_loss_fn( loss_fn_args = (logprobs,) - # TODO: PreferenceLoss, DistillationLossFn + elif isinstance(self.loss_fn, PreferenceLoss): + loss_fn_args = (logits,) + + elif isinstance(self.loss_fn, DistillationLossFn): + calculate_entropy = ( + self.loss_fn.zero_outside_topk and self.loss_fn.kl_type != "forward" + ) + student_topk_logprobs, teacher_topk_logprobs, H_all = ( + get_distilllation_topk_logprobs_from_logits( + student_logits=logits, + teacher_topk_logits=mb["teacher_topk_logits"], + teacher_topk_indices=mb["teacher_topk_indices"], + zero_outside_topk=self.loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + ) + + loss_fn_args = (student_topk_logprobs, teacher_topk_logprobs, H_all) + + else: + raise ValueError(f"Unknown loss function type: {type(self.loss_fn)}") return loss_fn_args diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 951d2cc226..1dc80cf7b5 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -26,7 +26,10 @@ ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import get_logprobs_from_logits +from nemo_rl.distributed.model_utils import ( + get_distilllation_topk_logprobs_from_logits, + get_logprobs_from_logits, +) basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -1653,8 +1656,17 @@ def test_distillation_loss_different_settings(kl_type, zero_outside_topk): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, metrics = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1697,8 +1709,17 @@ def test_distillation_loss_topk_filtering(k, zero_outside_topk): } ) - loss, metrics = loss_fn( - student_logits, + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + + loss, _ = loss_fn( + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1736,13 +1757,13 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): - loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + _ = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, ) @@ -1761,8 +1782,17 @@ def test_distillation_loss_gradient_flow(): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1794,8 +1824,17 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=zero_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - zero_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1807,8 +1846,16 @@ def test_distillation_loss_edge_cases(): # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=large_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - large_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1820,8 +1867,16 @@ def test_distillation_loss_edge_cases(): # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=small_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, _ = loss_fn( - small_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( @@ -1869,8 +1924,17 @@ def test_distillation_loss_fn_call(): } ) + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + loss_fn_args = get_distilllation_topk_logprobs_from_logits( + student_logits=student_logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + ) + loss, metrics = loss_fn( - student_logits, + *loss_fn_args, data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(