cross entropy kernel: inplace modification of input #343
Thanks! mark_dirty() can definitely be a potential resolution to the inplace calculation warning.
I guess this kernel is a gradient accumulation kernel, not a gradient calculation kernel? I suppose there is another calculation kernel being called in the torch version? Below is the trace of Liger CE's backward: I think the main reason Liger's CE does forward and backward in a single kernel, in-place, is for the fused linear CE use case, which has to acquire gradients and handle the previous layer's gradients as well, so in-place calculation won't be an issue there.
Yes. I was surprised that this accumulation kernel could take that much time. I was then thinking that maybe the grad could be saved into _input.grad directly in the forward kernel to avoid this accumulation, but I am not sure whether that can be done.
This makes sense.
Hello,
Regarding in-place modification of PyTorch tensors, there are already multiple open issues (#254, #262, #272). I would also like to point out that, according to the PyTorch docs for mark_dirty(), every tensor that is modified in-place during forward() should be passed to mark_dirty().
So can you please look into it and come up with a resolution? :-)
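For reference, here is a minimal, generic sketch of the mark_dirty() pattern the docs describe. This is illustrative only, not Liger's actual autograd Function; also note that, as far as I understand autograd's checks, a tensor passed to mark_dirty() must additionally be returned as an output of the Function:

```python
import torch
from torch.autograd.function import once_differentiable

class InplaceScale(torch.autograd.Function):
    """Illustrative only: scales its input in-place and reports that to autograd."""

    @staticmethod
    def forward(ctx, x, alpha):
        x.mul_(alpha)        # in-place modification of the input
        ctx.mark_dirty(x)    # declare the in-place modification to autograd
        ctx.alpha = alpha
        return x             # the dirty tensor is also returned as an output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        return grad_output * ctx.alpha, None

# .clone() so the modified tensor is not a leaf (an in-place op on a leaf that
# requires grad would be an error)
x = torch.randn(4, requires_grad=True).clone()
y = InplaceScale.apply(x, 2.0)
y.sum().backward()
```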
Now I noticed another thing.
Since the gradient is currently computed during the forward pass, if the grad of the loss w.r.t. the cross entropy output is 1.0 during the backward pass, no computation is done and only the gradient computed during the forward pass is returned. Even so, the backward pass seems to take almost as long as the forward pass, and I do not understand why that would be so.
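The behaviour described above roughly corresponds to something like the following (a paraphrased sketch based on my reading, not the actual Liger source):

```python
import torch

class CEWithPrecomputedGrad(torch.autograd.Function):
    """Illustrative only: the gradient w.r.t. the logits is computed during
    forward and simply reused in backward when grad_output is 1.0."""

    @staticmethod
    def forward(ctx, logits, target):
        loss = torch.nn.functional.cross_entropy(logits, target)  # mean reduction
        # gradient of the mean CE loss w.r.t. logits, computed already in forward
        grad = torch.softmax(logits, dim=-1)
        grad[torch.arange(logits.shape[0]), target] -= 1.0
        grad /= logits.shape[0]
        ctx.save_for_backward(grad)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        if torch.equal(grad_output, torch.ones_like(grad_output)):
            # grad_output == 1.0: nothing to compute, return the precomputed grad
            return grad, None
        return grad * grad_output, None
```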
I profiled both forward and backward passes using Triton's proton profiler with the following script:
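(The original script was not captured here; below is a minimal sketch of what such a profiling run could look like. The proton calls — proton.start / proton.scope / proton.finalize — and the LigerCrossEntropyLoss import reflect my understanding of recent Triton and Liger-Kernel versions and may need adjusting.)

```python
import torch
import triton.profiler as proton
from liger_kernel.transformers import LigerCrossEntropyLoss

# hypothetical problem size: roughly LLM-vocab-sized logits
V, B, T = 128_000, 4, 2048
logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B * T,), device="cuda")
loss_fn = LigerCrossEntropyLoss()

proton.start("liger_ce")              # profile written out on finalize
with proton.scope("forward"):
    loss = loss_fn(logits, target)
with proton.scope("backward"):
    loss.backward()
torch.cuda.synchronize()
proton.finalize()
# inspect with: proton-viewer -m time/ms liger_ce.hatchet
```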
Below is the time breakdown among different kernels:
Notice that during backward, almost 100% of the time is spent in the last kernel, namely `_ZN2at6native29vectorized_elementwise_kernelILi4ENS0_15CUDAFunctor_addIfEENS_6detail5ArrayIPcLi3EEEEEviT0_T1_` (which demangles to at::native::vectorized_elementwise_kernel with CUDAFunctor_add&lt;float&gt;, i.e. the elementwise add used for gradient accumulation). I have checked that this kernel takes a similar amount of time in the PyTorch/HuggingFace version of the calculation. So are we gaining much from the in-place calculation? It already seems illegal, and there is hardly any backward pass speedup even with it.