cross entropy kernel: inplace modification of input #343

Open
mgrabban opened this issue Nov 4, 2024 · 2 comments

Comments

@mgrabban (Collaborator) commented Nov 4, 2024

Hello,
Regarding in-place modification of PyTorch tensors, there are already multiple issues (#254, #262, #272). I would also like to point out that, according to the PyTorch docs for mark_dirty():

> Every tensor that's been modified in-place in a call to forward() should be given to this function, to ensure correctness of our checks

So can you please look into it and come up with a resolution? :-)
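
For reference, here is a minimal sketch of the pattern the docs describe (illustrative only, not the Liger implementation): a custom autograd Function that modifies its input in place and declares that via mark_dirty().

import torch


class InplaceScale(torch.autograd.Function):
    # Toy Function that overwrites its input in place; mark_dirty() tells autograd about it.
    @staticmethod
    def forward(ctx, x):
        x.mul_(2.0)          # in-place modification of the input tensor
        ctx.mark_dirty(x)    # every in-place-modified input must be reported here
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2.0


x = torch.randn(4, requires_grad=True)
y = InplaceScale.apply(x.clone())  # clone: a leaf that requires grad cannot be modified in place
y.sum().backward()
print(x.grad)  # all 2s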

Now I noticed another thing.
Since the gradient is currently calculated during the forward pass, the backward pass does no calculation when the grad of the loss w.r.t. the cross entropy output is 1.0; it only returns the grad that was computed during the forward pass. However, even then the backward pass takes almost the same time as the forward pass, and I do not understand why that would be so.
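
To make that shortcut concrete, here is a rough sketch of the pattern (not the actual Liger code; plain PyTorch ops stand in for the Triton kernel): forward() also materializes d(loss)/d(logits), and backward() either returns it as-is when grad_output is 1.0 or rescales it.

import torch
import torch.nn.functional as F


class CrossEntropySketch(torch.autograd.Function):
    # Sketch only: the forward pass also computes d(loss)/d(logits); backward just reuses it.
    @staticmethod
    def forward(ctx, logits, target):
        with torch.no_grad():
            log_probs = F.log_softmax(logits, dim=-1)
            loss = F.nll_loss(log_probs, target)
            # Mean-reduced CE gradient: (softmax - one_hot) / N
            grad = (log_probs.exp()
                    - F.one_hot(target, logits.shape[-1]).to(logits.dtype)) / logits.shape[0]
        ctx.save_for_backward(grad)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        if torch.equal(grad_output, torch.ones_like(grad_output)):
            return grad, None                # grad of loss w.r.t. output is 1.0: reuse directly
        return grad * grad_output, None      # otherwise rescale the cached gradient


logits = torch.randn(8, 10, requires_grad=True)
loss = CrossEntropySketch.apply(logits, torch.randint(10, (8,)))
loss.backward()  # grad_output is 1.0, so the cached gradient is returned unchanged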

I profiled both forward and backward passes using Triton's proton profiler with the following script:

import torch
# from torch.nn import CrossEntropyLoss as TorchCrossEntropyLoss

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

import triton.profiler as proton


device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.xpu.is_available():
    device = 'xpu'
else:
    print('No accelerator available, running on CPU')


def run():
    # Create Tensors to hold input and outputs.
    with proton.scope("0-init"):
        V = 131072
        B = 8
        T = 2048

        _input = torch.randn(B * T, V, requires_grad=True, device=device)
        target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

        # layer = TorchCrossEntropyLoss()
        layer = LigerCrossEntropyLoss()

    with proton.scope("1-warmup"):
        for _ in range(10):
            y = layer(_input, target)
            y.backward(retain_graph=True)

    for _ in range(100):
        # Forward pass:
        with proton.scope("2-forward"):
            y = layer(_input, target)

        # Backward pass:
        with proton.scope("3-backward"):
            y.backward(retain_graph=True)


func = proton.profile(run, name="cross_entropy", context='shadow')
func()

# Write out the profile
# Visualize using `proton-viewer -m time/ms ./cross_entropy.hatchet`
proton.finalize()

Below is the time breakdown among different kernels:

3361.863 ROOT
├─ 8.295 0-init
│  ├─ 8.292 _ZN2at6native54_GLOBAL__N__d8ceb000_21_DistributionNormal_cu_0c5b6e8543distribution_elementwise_grid_stride_kernelIfLi4EZNS0_9templates4cuda20normal_and_transformIffPNS_17CUDAGeneratorImplEZZZNS4_13normal_kernelIS7_EEvRKNS_10TensorBaseEddT_ENKUlvE_clEvENKUlvE0_clEvEUlfE_EEvRNS_18TensorIteratorBaseET1_T2_EUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIff6float4S7_SM_SF_EEvSH_SJ_RKT3_T4_EUlifE_EEviNS_15PhiloxCudaStateESI_SJ_
│  └─ 0.003 _ZN2at6native60_GLOBAL__N__2b2da9da_27_DistributionRandomKernel_cu_f88cee4443distribution_elementwise_grid_stride_kernelIjLi4EZZZNS0_9templates4cuda21random_from_to_kernelIPNS_17CUDAGeneratorImplEEEvRNS_18TensorIteratorBaseEmlT_ENKUlvE_clEvENKUlvE2_clEvEUlP24curandStatePhilox4_32_10E0_ZNS1_27distribution_nullary_kernelIlj5uint4S7_SF_ZZZS5_IS7_EvS9_mlSA_ENKSB_clEvENKSC_clEvEUljE_EEvS9_T2_RKT3_T4_EUlijE_EEviNS_15PhiloxCudaStateET1_SJ_
├─ 291.116 1-warmup
│  ├─ 0.087 _ZN2at6native13reduce_kernelILi512ELi1ENS0_8ReduceOpIfNS0_14func_wrapper_tIfZNS0_11sum_functorIfffEclERNS_14TensorIteratorEEUlffE_EEjfLi4EEEEEvT1_
│  ├─ 0.058 _ZN2at6native13reduce_kernelILi512ELi1ENS0_8ReduceOpIlNS0_14func_wrapper_tIlZNS0_11sum_functorIlllEclERNS_14TensorIteratorEEUlllE_EEjlLi4EEEEEvT1_
│  ├─ 0.041 _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE2_clEvEUllE_NS_6detail5ArrayIPcLi2EEE23TrivialOffsetCalculatorILi1EjESC_NS0_6memory12LoadWithCastILi1EEENSD_13StoreWithCastILi1EEEEEviT_T0_T1_T2_T3_T4_
│  ├─ 0.039 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_11FillFunctorIfEENS_6detail5ArrayIPcLi1EEEEEviT0_T1_
│  ├─ 0.027 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_13AUnaryFunctorIllbNS0_51_GLOBAL__N__28ce311f_18_CompareEQKernel_cu_d8008c9616CompareEqFunctorIlEEEENS_6detail5ArrayIPcLi2EEEEEviT0_T1_
│  ├─ 0.015 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_13BinaryFunctorIffbNS0_51_GLOBAL__N__28ce311f_18_CompareEQKernel_cu_d8008c9616CompareEqFunctorIfEEEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_
│  ├─ 136.551 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_
│  └─ 154.297 liger_cross_entropy_kernel
├─ 1545.133 2-forward
│  ├─ 0.878 _ZN2at6native13reduce_kernelILi512ELi1ENS0_8ReduceOpIfNS0_14func_wrapper_tIfZNS0_11sum_functorIfffEclERNS_14TensorIteratorEEUlffE_EEjfLi4EEEEEvT1_
│  ├─ 0.569 _ZN2at6native13reduce_kernelILi512ELi1ENS0_8ReduceOpIlNS0_14func_wrapper_tIlZNS0_11sum_functorIlllEclERNS_14TensorIteratorEEUlllE_EEjlLi4EEEEEvT1_
│  ├─ 0.401 _ZN2at6native27unrolled_elementwise_kernelIZZZNS0_23direct_copy_kernel_cudaERNS_18TensorIteratorBaseEENKUlvE1_clEvENKUlvE2_clEvEUllE_NS_6detail5ArrayIPcLi2EEE23TrivialOffsetCalculatorILi1EjESC_NS0_6memory12LoadWithCastILi1EEENSD_13StoreWithCastILi1EEEEEviT_T0_T1_T2_T3_T4_
│  ├─ 0.207 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_11FillFunctorIfEENS_6detail5ArrayIPcLi1EEEEEviT0_T1_
│  ├─ 0.275 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_13AUnaryFunctorIllbNS0_51_GLOBAL__N__28ce311f_18_CompareEQKernel_cu_d8008c9616CompareEqFunctorIlEEEENS_6detail5ArrayIPcLi2EEEEEviT0_T1_
│  └─ 1542.804 liger_cross_entropy_kernel
└─ 1517.319 3-backward
   ├─ 0.171 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_11FillFunctorIfEENS_6detail5ArrayIPcLi1EEEEEviT0_T1_
   ├─ 0.148 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_13BinaryFunctorIffbNS0_51_GLOBAL__N__28ce311f_18_CompareEQKernel_cu_d8008c9616CompareEqFunctorIfEEEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_
   └─ 1517.000 _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_

Notice that during backward, almost 100% of the time is spent in the last kernel, namely _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_. I have checked that this kernel takes a similar amount of time in the PyTorch/HuggingFace version of the calculation.
So are we gaining much from the inplace calculation? It already seems illegal, and there is hardly any backward-pass speedup even with it.
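
For what it's worth, here is a minimal standalone repro of the pattern I suspect is behind that add kernel (just a guess on my part): the script above never zeroes gradients between backward() calls, so from the second call onward autograd has to accumulate into the existing _input.grad with an elementwise add over the whole (B*T, V) tensor.

import torch

x = torch.randn(8, requires_grad=True)
for i in range(3):
    (x * 2.0).sum().backward()
    # Iteration 0: .grad is simply populated.
    # Iterations 1+: autograd accumulates into the existing .grad with an elementwise add,
    # which for a (B*T, V) fp32 tensor is a large, memory-bound kernel.
    print(i, x.grad[0].item())  # 2.0, 4.0, 6.0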

@Tcc0403 (Collaborator) commented Nov 4, 2024

Thanks! mark_dirty() can definitely be a potential resolution to the inplace-calculation warning.

> Notice that during backward, almost 100% of the time is spent in the last kernel, namely _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_. I have checked that this kernel takes a similar amount of time in the PyTorch/HuggingFace version of the calculation.

I guess this kernel is the gradient accumulation kernel, not the gradient calculation? I suppose another calculation kernel is called in the torch version? Below is the trace of Liger CE's backward:
[screenshot: trace of Liger CE backward]

I think the main reason Liger's CE does forward and backward in a single kernel + inplace is the fused linear CE usage, which has to acquire the gradients and handle the previous layer's gradients as well, so the inplace calculation won't be an issue there.
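
Roughly, the fused linear + CE path reuses that materialized d(loss)/d(logits) to form the linear layer's gradients, something like this sketch (illustrative names and shapes, not the actual fused kernel):

import torch

B_T, H, V = 16, 64, 128                      # illustrative sizes: tokens, hidden dim, vocab
hidden = torch.randn(B_T, H)
weight = torch.randn(V, H)                   # lm_head weight
target = torch.randint(V, (B_T,))

logits = hidden @ weight.T                   # forward of the linear layer
grad_logits = torch.softmax(logits, dim=-1)  # the CE kernel writes d(loss)/d(logits) here in place
grad_logits[torch.arange(B_T), target] -= 1.0
grad_logits /= B_T

# The gradient buffer is consumed immediately to produce the linear layer's gradients:
grad_hidden = grad_logits @ weight           # d(loss)/d(hidden), needed by the previous layer
grad_weight = grad_logits.T @ hidden         # d(loss)/d(weight)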

@mgrabban (Collaborator, Author) commented Nov 4, 2024

> Thanks! mark_dirty() can definitely be a potential resolution to the inplace-calculation warning.
>
> > Notice that during backward, almost 100% of the time is spent in the last kernel, namely _ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_. I have checked that this kernel takes a similar amount of time in the PyTorch/HuggingFace version of the calculation.
>
> I guess this kernel is the gradient accumulation kernel, not the gradient calculation? I suppose another calculation kernel is called in the torch version? Below is the trace of Liger CE's backward:

Yes. I was surprised that this accumulation kernel could take that much time. Then I was thinking maybe the grad could be saved into _input.grad directly in the forward kernel to avoid this accumulation, but I am not sure whether this can be done.
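
Roughly, the idea would be something like this (illustrative only; precomputed_grad stands in for whatever the forward kernel produces, and I am not sure the Triton kernel can target the .grad storage directly):

import torch

_input = torch.randn(4, 8, requires_grad=True)
precomputed_grad = torch.randn_like(_input)   # stand-in for the gradient the forward kernel computes

if _input.grad is None:
    _input.grad = precomputed_grad            # hand the buffer over directly, no extra add kernel
else:
    _input.grad.add_(precomputed_grad)        # if .grad already exists, an add is still needed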

> I think the main reason Liger's CE does forward and backward in a single kernel + inplace is the fused linear CE usage, which has to acquire the gradients and handle the previous layer's gradients as well, so the inplace calculation won't be an issue there.

This makes sense.
