From 18fa8b73f2ab0ed1372ca5cd34b02771321f7bdc Mon Sep 17 00:00:00 2001
From: Saurabh
Date: Sat, 9 Nov 2024 17:29:46 -0800
Subject: [PATCH] init and backward pass reduction

---
 src/liger_kernel/ops/tvd.py               | 11 ++---------
 src/liger_kernel/transformers/__init__.py |  1 +
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/src/liger_kernel/ops/tvd.py b/src/liger_kernel/ops/tvd.py
index 4c0df93ce..1099a3ec4 100644
--- a/src/liger_kernel/ops/tvd.py
+++ b/src/liger_kernel/ops/tvd.py
@@ -111,11 +111,11 @@ def tv_distance_forward_triton(p, q, reduction):
     )
 
     if reduction == _REDUCTION_MODE_BATCHMEAN.value:
-        return output_tensor.sum() / BT, grads
+        return output_tensor.sum() / BT, grads / BT
     elif reduction == _REDUCTION_MODE_SUM.value:
         return output_tensor.sum(dim=0), grads
     elif reduction == _REDUCTION_MODE_MEAN.value:
-        return output_tensor.sum() / (BT * V), grads
+        return output_tensor.sum() / (BT * V), grads / (BT * V)
     else:
         return output_tensor, grads
 
@@ -155,7 +155,6 @@ def forward(
         """
         loss, grads = tv_distance_forward_triton(p, q, reduction)
         ctx.save_for_backward(grads)
-        ctx.reduction = reduction
         return loss
 
     @staticmethod
@@ -171,13 +170,7 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
             tuple[torch.Tensor, None, None]: The gradient of the loss with respect to the inputs.
         """
         (grads,) = ctx.saved_tensors
-        BT, V = grads.shape
         grads = tvd_backward_triton(grad_output, grads)
 
-        if ctx.reduction == "batchmean":
-            grads /= BT
-        elif ctx.reduction == "mean":
-            grads /= BT * V
-
         return grads, None, None
 
diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index ffb8235cc..7a8d4feea 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -8,6 +8,7 @@
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
+from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     _apply_liger_kernel,
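
Reviewer note (not part of the patch above): a minimal pure-PyTorch sketch of why pre-dividing the cached grads by BT (batchmean) or BT * V (mean) at forward time is equivalent to the reduction branch the patch removes from backward(). It assumes the standard total variation distance, loss = 0.5 * sum(|p - q|); the Triton kernel's exact forward math is not shown in this diff, and the BT / V names follow the grads.shape unpacking the old backward used.

import torch

BT, V = 4, 8  # hypothetical batch*seq-len and vocab sizes
p = torch.randn(BT, V).softmax(dim=-1).requires_grad_(True)
q = torch.randn(BT, V).softmax(dim=-1)

# Reference "batchmean" TVD computed via autograd.
loss = 0.5 * (p - q).abs().sum() / BT
loss.backward()

# Gradients pre-divided by BT at forward time, as the patched kernel now caches
# them: backward() can return them (scaled only by grad_output) with no
# reduction-specific branch and without stashing ctx.reduction on the context.
pre_divided = 0.5 * torch.sign(p.detach() - q) / BT
assert torch.allclose(p.grad, pre_divided)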