init and backward pass reduction
saurabhkoshatwar committed Nov 10, 2024
1 parent bd5f976 commit 18fa8b7
Showing 2 changed files with 3 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/liger_kernel/ops/tvd.py
@@ -111,11 +111,11 @@ def tv_distance_forward_triton(p, q, reduction):
     )
 
     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
-        return output_tensor.sum() / BT, grads
+        return output_tensor.sum() / BT, grads / BT
     elif reduction == _REDUCTION_MODE_SUM.value:
         return output_tensor.sum(dim=0), grads
     elif reduction == _REDUCTION_MODE_MEAN.value:
-        return output_tensor.sum() / (BT * V), grads
+        return output_tensor.sum() / (BT * V), grads / (BT * V)
     else:
         return output_tensor, grads

@@ -155,7 +155,6 @@ def forward(
         """
         loss, grads = tv_distance_forward_triton(p, q, reduction)
         ctx.save_for_backward(grads)
-        ctx.reduction = reduction
         return loss
 
     @staticmethod
@@ -171,13 +170,7 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
             tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs.
         """
         (grads,) = ctx.saved_tensors
-        BT, V = grads.shape
 
         grads = tvd_backward_triton(grad_output, grads)
 
-        if ctx.reduction == "batchmean":
-            grads /= BT
-        elif ctx.reduction == "mean":
-            grads /= BT * V
-
         return grads, None, None
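The change moves the reduction scaling out of backward and into forward: the per-element gradients saved in ctx are pre-divided by BT (batchmean) or BT * V (mean), so backward reduces to a single chain-rule multiply and ctx.reduction can be dropped. A minimal sketch of this pattern, using a hypothetical pure-PyTorch TVD in place of the Triton kernels (TVDSketch and its internals are illustrative, not the library's API):

import torch

class TVDSketch(torch.autograd.Function):
    # Total variation distance: TVD(P, Q) = 0.5 * sum |p - q|.

    @staticmethod
    def forward(ctx, p, q, reduction="batchmean"):
        BT, V = p.shape
        diff = p - q
        loss = 0.5 * diff.abs().sum()
        grads = 0.5 * torch.sign(diff)  # d(loss)/d(p), per element
        if reduction == "batchmean":
            loss, grads = loss / BT, grads / BT
        elif reduction == "mean":
            loss, grads = loss / (BT * V), grads / (BT * V)
        # The reduction is baked into grads here, so backward no longer
        # needs ctx.reduction or the per-mode division this commit removes.
        ctx.save_for_backward(grads)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        # Chain rule only; one None per non-differentiated forward input.
        return grads * grad_output, None, None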
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -8,6 +8,7 @@
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
+from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     _apply_liger_kernel,
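With the export in place, the new loss is importable from the top-level transformers namespace. A hedged usage sketch, assuming LigerTVDLoss follows the same convention as the other Liger loss modules and takes two probability tensors (the constructor argument and the CUDA device are assumptions, since the Triton kernels run on GPU):

import torch
from liger_kernel.transformers import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean")  # assumed constructor signature
logits = torch.randn(4, 32, device="cuda", requires_grad=True)
p = logits.softmax(dim=-1)
q = torch.randn(4, 32, device="cuda").softmax(dim=-1)
loss = tvd(p, q)  # scalar under "batchmean"
loss.backward()   # gradient already carries the 1/BT scaling from forward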
