Skip to content

Commit

Permalink
add reference model logps to chunkedloss interface and fix dpo loss fn (
Browse files Browse the repository at this point in the history
#405)

accomodate reference model logps in chunked loss interface and make dpo
loss use reference model logps in its loss function
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
as title
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->

<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
  • Loading branch information
shivam15s authored Nov 22, 2024
1 parent 317ff43 commit d907ec0
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 46 deletions.
40 changes: 36 additions & 4 deletions src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,31 @@
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
def preference_loss_fn(
chosen_logps,
rejected_logps,
ref_chosen_logps=None,
ref_rejected_logps=None,
beta=0.1,
):
"""
Compute DPO loss (Direct Preference Optimization).
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the direct preference loss.
"""
logits_diff = beta * (chosen_logps - rejected_logps)
if ref_chosen_logps is None:
ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
if ref_rejected_logps is None:
ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)

chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps

logits_diff = beta * (chosen_logratios - rejected_logratios)
losses = -F.logsigmoid(logits_diff)
return losses.sum()

Expand All @@ -28,10 +44,13 @@ def forward(
weight,
target,
bias=None,
ref_weight=None,
ref_bias=None,
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
compiled=True,
use_ref_model=True,
):
"""
Fused linear layer with DPO (Direct Preference Optimization) loss.
Expand All @@ -48,14 +67,17 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
compiled=compiled,
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
)

@staticmethod
def backward(ctx, grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None
return *grads, None, None, None, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
Expand All @@ -69,26 +91,36 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = True,
compiled: bool = True,
use_ref_model: bool = False,
):
"""
Args:
ignore_index (int): Index to ignore in the loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute the NLL loss.
compiled (bool): Whether to use the torch compiled kernel.
use_ref_model (bool): Whether to use a reference model for the DPO loss.
"""
super().__init__()
self.ignore_index = ignore_index
self.beta = beta
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled
self.use_ref_model = use_ref_model

def forward(self, lin_weight, _input, target, bias=None):
def forward(
self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
):
return LigerFusedLinearDPOFunction.apply(
_input,
lin_weight,
target,
bias,
ref_weight,
ref_bias,
self.ignore_index,
self.beta,
self.compute_nll_loss,
self.compiled,
self.use_ref_model,
)
106 changes: 79 additions & 27 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,42 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
"""
raise NotImplementedError("Preference loss function must be implemented.")

@staticmethod
def chunk_forward(
input_chunk,
weight,
target_chunk,
bias=None,
ignore_index=-100,
compute_nll_loss=True,
):
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
if bias is not None:
logits_chunk = logits_chunk + bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
reduction="sum",
ignore_index=ignore_index,
)

loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]
return chosen_logps, rejected_logps, chosen_nll_loss

@staticmethod
def forward(
ctx,
Expand All @@ -32,6 +68,9 @@ def forward(
beta=0.1,
compute_nll_loss=True,
compiled=True,
use_ref_model=False,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
):
"""
Expand All @@ -49,7 +88,11 @@ def forward(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
Expand All @@ -61,7 +104,6 @@ def forward(
grad_bias = torch.zeros_like(bias) if bias is not None else None
loss_acc = torch.zeros((), device=_input.device)

chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
loss_func_to_call = partial(
LigerFusedLinearPreferenceBase._compute_loss,
preference_loss_fn=loss_fn,
Expand All @@ -70,6 +112,9 @@ def forward(
beta=beta,
compute_nll_loss=compute_nll_loss,
full_target=target,
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
**loss_kwargs,
)

Expand Down Expand Up @@ -101,6 +146,7 @@ def accumulate_chunk(input_chunk, target_chunk):
accumulate_chunk = torch.compile(accumulate_chunk)

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
Expand Down Expand Up @@ -159,6 +205,9 @@ def _compute_loss(
alpha=1.0,
beta=0.1,
compute_nll_loss=True,
use_ref_model=False,
ref_weight=None,
ref_bias=None,
**loss_kwargs,
):
"""
Expand All @@ -173,38 +222,41 @@ def _compute_loss(
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compute_nll_loss (bool): Whether to compute NLL loss.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
loss_kwargs (dict): Additional arguments for the loss function.
"""
len_chosen_chunk = target_chunk.shape[0] // 2

logits_chunk = input_chunk @ weight.t() # chunk_size x V
if bias is not None:
logits_chunk = logits_chunk + bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
reduction="sum",
chosen_logps, rejected_logps, chosen_nll_loss = (
LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
weight,
target_chunk,
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)

loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]
if use_ref_model:
with torch.no_grad():
ref_chosen_logps, ref_rejected_logps, _ = (
LigerFusedLinearPreferenceBase.chunk_forward(
input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False,
)
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

alignment_loss = preference_loss_fn(
chosen_logps, rejected_logps, beta=beta, **loss_kwargs
Expand Down
Loading

0 comments on commit d907ec0

Please sign in to comment.