From d907ec0c09a0a998ecc1073afb195f1942a6c24f Mon Sep 17 00:00:00 2001 From: Shivam Sahni Date: Thu, 21 Nov 2024 21:25:14 -0800 Subject: [PATCH] add reference model logps to chunkedloss interface and fix dpo loss fn (#405) accommodate reference model logps in chunked loss interface and make dpo loss use reference model logps in its loss function ## Summary as title ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- src/liger_kernel/chunked_loss/dpo_loss.py | 40 ++++++- .../chunked_loss/fused_linear_preference.py | 106 +++++++++++++----- test/chunked_loss/test_dpo_loss.py | 70 ++++++++++-- test/utils.py | 46 +++++++- 4 files changed, 216 insertions(+), 46 deletions(-) diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py index 601c15c3d..4ad870ff1 100644 --- a/src/liger_kernel/chunked_loss/dpo_loss.py +++ b/src/liger_kernel/chunked_loss/dpo_loss.py @@ -9,15 +9,31 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): @staticmethod - def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): + def preference_loss_fn( + chosen_logps, + rejected_logps, + ref_chosen_logps=None, + ref_rejected_logps=None, + beta=0.1, + ): """ Compute DPO loss (Direct Preference Optimization). Args: chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). + ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,). + ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,). beta (float): Weight for the direct preference loss. 
""" - logits_diff = beta * (chosen_logps - rejected_logps) + if ref_chosen_logps is None: + ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) + if ref_rejected_logps is None: + ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) + + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + + logits_diff = beta * (chosen_logratios - rejected_logratios) losses = -F.logsigmoid(logits_diff) return losses.sum() @@ -28,10 +44,13 @@ def forward( weight, target, bias=None, + ref_weight=None, + ref_bias=None, ignore_index=-100, beta=0.1, compute_nll_loss=True, compiled=True, + use_ref_model=True, ): """ Fused linear layer with DPO (Direct Preference Optimization) loss. @@ -48,6 +67,9 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, compiled=compiled, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, ) @staticmethod @@ -55,7 +77,7 @@ def backward(ctx, grad_output): # Get gradients for _input, weight, bias, and target from the base class grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] # Return these gradients, followed by None for the remaining inputs - return *grads, None, None, None, None + return *grads, None, None, None, None, None, None, None class LigerFusedLinearDPOLoss(torch.nn.Module): @@ -69,26 +91,36 @@ def __init__( beta: float = 0.1, compute_nll_loss: bool = True, compiled: bool = True, + use_ref_model: bool = False, ): """ Args: ignore_index (int): Index to ignore in the loss. beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute the NLL loss. + compiled (bool): Whether to use the torch compiled kernel. + use_ref_model (bool): Whether to use a reference model for the DPO loss. 
""" super().__init__() self.ignore_index = ignore_index self.beta = beta self.compute_nll_loss = compute_nll_loss self.compiled = compiled + self.use_ref_model = use_ref_model - def forward(self, lin_weight, _input, target, bias=None): + def forward( + self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None + ): return LigerFusedLinearDPOFunction.apply( _input, lin_weight, target, bias, + ref_weight, + ref_bias, self.ignore_index, self.beta, self.compute_nll_loss, self.compiled, + self.use_ref_model, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 7dd2af160..ccf74ca04 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -18,6 +18,42 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1): """ raise NotImplementedError("Preference loss function must be implemented.") + @staticmethod + def chunk_forward( + input_chunk, + weight, + target_chunk, + bias=None, + ignore_index=-100, + compute_nll_loss=True, + ): + len_chosen_chunk = target_chunk.shape[0] // 2 + logits_chunk = input_chunk @ weight.t() + if bias is not None: + logits_chunk = logits_chunk + bias + log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) + + chosen_nll_loss = 0.0 + if compute_nll_loss: + chosen_nll_loss = F.nll_loss( + log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), + target_chunk[:len_chosen_chunk].view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + loss_mask = target_chunk != ignore_index + label_chunk = torch.where(loss_mask, target_chunk, 0) + + per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( + -1 + ) + average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + + chosen_logps = average_log_prob[:len_chosen_chunk] + rejected_logps = average_log_prob[len_chosen_chunk:] + return chosen_logps, 
rejected_logps, chosen_nll_loss + @staticmethod def forward( ctx, @@ -32,6 +68,9 @@ def forward( beta=0.1, compute_nll_loss=True, compiled=True, + use_ref_model=False, + ref_weight=None, + ref_bias=None, **loss_kwargs, ): """ @@ -49,7 +88,11 @@ def forward( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. compiled (bool): Whether to use torch compile for chunk accumulation. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Other possible arguments that a loss function might need """ # TODO: Tune CHUNK_SIZE to fully utilize the GPU @@ -61,7 +104,6 @@ def forward( grad_bias = torch.zeros_like(bias) if bias is not None else None loss_acc = torch.zeros((), device=_input.device) - chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) loss_func_to_call = partial( LigerFusedLinearPreferenceBase._compute_loss, preference_loss_fn=loss_fn, @@ -70,6 +112,9 @@ def forward( beta=beta, compute_nll_loss=compute_nll_loss, full_target=target, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, **loss_kwargs, ) @@ -101,6 +146,7 @@ def accumulate_chunk(input_chunk, target_chunk): accumulate_chunk = torch.compile(accumulate_chunk) len_chosen = target.shape[0] // 2 + chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0) _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0) _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) @@ -159,6 +205,9 @@ def _compute_loss( alpha=1.0, beta=0.1, compute_nll_loss=True, + use_ref_model=False, + ref_weight=None, + ref_bias=None, 
**loss_kwargs, ): """ @@ -173,38 +222,41 @@ def _compute_loss( ignore_index (int): Index to ignore for loss computation. alpha (float): Weight for the NLL loss. beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). loss_kwargs (dict): Additional arguments for the loss function. """ - len_chosen_chunk = target_chunk.shape[0] // 2 - - logits_chunk = input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1) - - chosen_nll_loss = 0.0 - if compute_nll_loss: - chosen_nll_loss = F.nll_loss( - log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), - reduction="sum", + chosen_logps, rejected_logps, chosen_nll_loss = ( + LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - - loss_mask = target_chunk != ignore_index - label_chunk = torch.where(loss_mask, target_chunk, 0) - - per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze( - -1 ) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] + if use_ref_model: + with torch.no_grad(): + ref_chosen_logps, ref_rejected_logps, _ = ( + LigerFusedLinearPreferenceBase.chunk_forward( + 
input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, + ) + ) + loss_kwargs["ref_chosen_logps"] = ref_chosen_logps + loss_kwargs["ref_rejected_logps"] = ref_rejected_logps alignment_loss = preference_loss_fn( chosen_logps, rejected_logps, beta=beta, **loss_kwargs diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py index e858626fd..2f9d1d94e 100644 --- a/test/chunked_loss/test_dpo_loss.py +++ b/test/chunked_loss/test_dpo_loss.py @@ -19,13 +19,19 @@ class HFDPOLoss(HFAlignmentLoss): Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py """ - def __init__(self, ignore_index: int = -100, beta: float = 0.1): - super().__init__(beta=beta, ignore_index=ignore_index) + def __init__( + self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True + ): + super().__init__( + beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model + ) def alignment_loss( self, policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, ): """Compute DPO loss for a batch of policy log probabilities. Args: @@ -36,7 +42,10 @@ def alignment_loss( The losses tensor contains the DPO loss for each example in the batch. 
""" # Derived from https://huggingface.co/papers/2305.18290 - logits_diff = self.beta * (policy_chosen_logps - policy_rejected_logps) + chosen_logratios = policy_chosen_logps - ref_chosen_logps + rejected_logratios = policy_rejected_logps - ref_rejected_logps + + logits_diff = self.beta * (chosen_logratios - rejected_logratios) losses = -F.logsigmoid(logits_diff) return losses @@ -48,6 +57,7 @@ def __init__( V: int, dtype: torch.dtype, bias: bool = False, + ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -55,12 +65,17 @@ def __init__( self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) self.dpo_loss = HFDPOLoss( - ignore_index=ignore_index, beta=beta + ignore_index=ignore_index, beta=beta, use_ref_model=True ).get_batch_loss_metrics def forward(self, x, y): - return self.dpo_loss(self.lin.weight, x, y, self.lin.bias) + return self.dpo_loss( + self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias + ) class LigerLMHeadDPO(torch.nn.Module): @@ -70,6 +85,7 @@ def __init__( V: int, dtype: torch.dtype, bias: bool = False, + ref_bias: bool = False, ignore_index: int = -100, beta: float = 0.1, ): @@ -77,10 +93,17 @@ def __init__( self.lin = torch.nn.Linear( in_features=H, out_features=V, bias=bias, dtype=dtype ) - self.dpo_loss = LigerFusedLinearDPOLoss(ignore_index=ignore_index, beta=beta) + self.ref_lin = torch.nn.Linear( + in_features=H, out_features=V, bias=ref_bias, dtype=dtype + ) + self.dpo_loss = LigerFusedLinearDPOLoss( + ignore_index=ignore_index, beta=beta, use_ref_model=True + ) def forward(self, x, y): - return self.dpo_loss(self.lin.weight, x, y, self.lin.bias) + return self.dpo_loss( + self.lin.weight, x, y, self.lin.bias, self.ref_lin.weight, self.ref_lin.bias + ) @pytest.mark.parametrize( @@ -98,8 +121,11 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, 
False]) +@pytest.mark.parametrize("ref_bias", [True, False]) @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)]) -def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta): +def test_correctness( + B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta +): B = 2 * B # dpo loss requires B to be even torch_lm_head_dpo = TorchLMHeadDPO( @@ -107,6 +133,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, V=V, dtype=dtype, bias=bias, + ref_bias=ref_bias, ignore_index=ignore_index, beta=beta, ) @@ -115,6 +142,7 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, V=V, dtype=dtype, bias=bias, + ref_bias=ref_bias, ignore_index=ignore_index, beta=beta, ) @@ -122,11 +150,18 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, torch_lm_head_dpo.lin.weight.data = liger_lm_head_dpo.lin.weight.data = torch.randn( V, H, device="cuda", dtype=dtype ) + torch_lm_head_dpo.ref_lin.weight.data = liger_lm_head_dpo.ref_lin.weight.data = ( + torch.randn(V, H, device="cuda", dtype=dtype) + ) if bias: torch_lm_head_dpo.lin.bias.data = liger_lm_head_dpo.lin.bias.data = torch.randn( V, device="cuda", dtype=dtype ) + if ref_bias: + torch_lm_head_dpo.ref_lin.bias.data = liger_lm_head_dpo.ref_lin.bias.data = ( + torch.randn(V, device="cuda", dtype=dtype) + ) _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar input1 = _input.detach().clone().requires_grad_(True) @@ -186,7 +221,8 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, ], ) @pytest.mark.parametrize("bias", [True, False]) -def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): +@pytest.mark.parametrize("ref_bias", [True, False]) +def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias): B = 2 * B _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar @@ -208,12 +244,24 
@@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): weight1 = _weight.detach().clone().requires_grad_(True) weight2 = _weight.detach().clone().requires_grad_(True) + _ref_weight = torch.randn(V, H, device="cuda", dtype=dtype) + ref_weight1 = _ref_weight.detach().clone().requires_grad_(True) + ref_weight2 = _ref_weight.detach().clone().requires_grad_(True) + _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1 = LigerFusedLinearDPOFunction.apply(input1, weight1, target, bias1) - loss2 = liger_fused_linear_dpo(input2, weight2, target, bias2) + _ref_bias = torch.randn(V, device="cuda", dtype=dtype) if ref_bias else None + ref_bias1 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None + + loss1 = LigerFusedLinearDPOFunction.apply( + input1, weight1, target, bias1, ref_weight1, ref_bias1 + ) + loss2 = liger_fused_linear_dpo( + input2, weight2, target, bias2, ref_weight2, ref_bias2 + ) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index e65bbabdc..f209a0388 100644 --- a/test/utils.py +++ b/test/utils.py @@ -355,10 +355,17 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig): class HFAlignmentLoss: - def __init__(self, alpha: float = 1.0, beta: float = 0.1, ignore_index: int = -100): + def __init__( + self, + alpha: float = 1.0, + beta: float = 0.1, + ignore_index: int = -100, + use_ref_model: bool = False, + ): self.alpha = alpha self.beta = beta self.ignore_index = ignore_index + self.use_ref_model = use_ref_model @abstractmethod def alignment_loss(self): @@ -400,6 +407,27 @@ def get_batch_logps( else: return (per_token_logps * loss_mask).sum(-1) + def get_ref_logps( + self, + _input: torch.FloatTensor, + 
ref_weight: torch.FloatTensor, + target: torch.LongTensor, + ref_bias: torch.FloatTensor, + average_log_prob: bool = True, + ): + """Compute the log probabilities of the given labels under the given reference model.""" + + ref_logits = _input @ ref_weight.t() + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_all_logps = self.get_batch_logps( + ref_logits, target, average_log_prob=average_log_prob + ) + return ( + ref_all_logps[: _input.shape[0] // 2], + ref_all_logps[_input.shape[0] // 2 :], + ) + def concatenated_forward( self, _input: torch.FloatTensor, @@ -462,7 +490,8 @@ def get_batch_loss_metrics( _input: torch.FloatTensor, target: torch.LongTensor, bias: torch.FloatTensor = None, - alpha: float = 1.0, + ref_weight: torch.FloatTensor = None, + ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, ): """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" @@ -478,7 +507,16 @@ def get_batch_loss_metrics( policy_nll_loss, ) = forward_output[:5] - losses = self.alignment_loss(policy_chosen_logps, policy_rejected_logps) + loss_kwargs = {} + if self.use_ref_model: + ref_chosen_logps, ref_rejected_logps = self.get_ref_logps( + _input, ref_weight, target, ref_bias, average_log_prob + ) + loss_kwargs["ref_chosen_logps"] = ref_chosen_logps + loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + losses = self.alignment_loss( + policy_chosen_logps, policy_rejected_logps, **loss_kwargs + ) # full loss - loss = policy_nll_loss * alpha - losses.mean() + loss = policy_nll_loss * self.alpha - losses.mean() return loss