[Pytorch] Update context parallel softmax lse correction func #716
Conversation
Hi @Kite0011, thank you for your contribution! Could you please sign your commits (as outlined here: https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)?
Hi @Kite0011, really appreciate the fix. The math is correct. The problem is that the function should do an in-place update. softmax_lse is initialized here: softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double). Every time you call flash_attn_fwd_softmax_lse_correction, you should update that same initialized softmax_lse. Some calls update the whole softmax_lse, some only update half of it, but all of them should work on the same tensor; otherwise the final softmax_lse is wrong.
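For illustration, here is a minimal sketch of an in-place, overflow-safe LSE merge of the kind being discussed (the function name and calling pattern follow the conversation above; the shapes and loop are illustrative, not the actual TE call sites):

```python
import torch

def flash_attn_fwd_softmax_lse_correction(softmax_lse: torch.Tensor,
                                           softmax_lse_per_step: torch.Tensor) -> None:
    """Merge two log-sum-exp tensors in place: lse = log(exp(a) + exp(b)).

    The max-shifted form keeps exp() arguments non-positive, so nothing
    overflows even when the raw lse values are large.
    """
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale))
    # In-place write: every caller keeps working on this same tensor.
    softmax_lse.copy_(new_scale)

# Illustrative accumulation loop: each step folds its lse into the same tensor.
softmax_lse_per_step = [torch.randn(2, 8, 1024, dtype=torch.double) * 100 for _ in range(4)]
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
for lse_step in softmax_lse_per_step[1:]:
    flash_attn_fwd_softmax_lse_correction(softmax_lse, lse_step)
```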
@Kite0011 out of curiosity, did you actually hit a case where the value is out of range even for the double data type?
Thank you for your response, I will modify my PR according to this link later. |
Thank you for your reply, you are correct, the …
Yeah, I think changing it to softmax_lse.copy_(new_scale) should work. Interesting that even double cannot cover the range. However, with your fix, I think even an FP32 softmax_lse should now work. Anyway, that needs some further testing; you can leave it as double for now, and I will make a change if necessary. Thanks for the fix, really appreciate it.
I've made a revision. Could you check whether there is anything else that needs to be done?
I have made a revision, could you please review it again? Thank you for your hard work. It also seems you're right that the lse-related operations should all be safe at this point; but considering that the lse itself doesn't take much memory or time, it's also fine to keep the data type as double.
LGTM. Thanks! |
LGTM after signing the commits.
By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also first be converted to fp32 for accumulation, and then converted back to fp16/bf16 before being copied back in place?
@Infi-zc I also considered this, but with BF16 results correction I did not see a loss curve issue. Did you encounter any case that needs FP32 accumulation?
Not yet; I'm not quite certain. I plan to conduct more experiments to verify this further.
Yeah, sounds good, let me know if you encounter the issue. Thanks. |
In my case, I didn't encounter a loss diff from bf16's fwd out and bwd dq, dkv; we have tried fp32 accumulation, but it didn't make much difference to the final result.
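For context, the two accumulation strategies being compared look roughly like this (a sketch under assumed shapes; the real per-step correction factors come from the softmax lse, but made-up values are enough to show the dtype handling):

```python
import torch

# Illustrative per-step partial outputs (bf16) and made-up correction factors.
per_step_out = [torch.randn(2, 1024, 8, 64, dtype=torch.bfloat16) for _ in range(4)]
factors = [0.1, 0.2, 0.3, 0.4]

# Option A: accumulate in fp32, cast back to bf16 only at the end.
acc = torch.zeros(2, 1024, 8, 64, dtype=torch.float32)
for f, o in zip(factors, per_step_out):
    acc += f * o.float()
out_fp32_path = acc.to(torch.bfloat16)

# Option B: correct and accumulate directly in bf16 (what the thread reports as
# sufficient in practice, with only ~1e-3..1e-2 step-level differences).
out_bf16_path = torch.zeros(2, 1024, 8, 64, dtype=torch.bfloat16)
for f, o in zip(factors, per_step_out):
    out_bf16_path += f * o
```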
@Kite0011 I see some unrelated commits in this PR - I think this is some rebase issue. Could you resolve that? |
5f8df41 to b381929
Signed-off-by: kitefang <[email protected]>
/te-ci pytorch
Merged, thank you for the contribution @Kite0011 ! |
I've noticed a slight discrepancy, with differences at the level of 1e-3 to 1e-2 at certain steps, when compared to the results without context parallelism (cp), or with cp where out and dqkv are accumulated using fp32. But their convergence trends are consistent. @xrennvidia @Kite0011
…#716) [Pytorch] Update context parallel softmax lse correction func. Signed-off-by: kitefang <[email protected]> Co-authored-by: kitefang <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
The original implementation would result in 'nan' when the value of lse.exp() exceeds the range of double, causing incorrect values and gradients at the corresponding positions.
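As a rough numerical illustration of the failure mode (not the actual TE code): merging two LSE values as log(exp(a) + exp(b)) overflows once a value passes roughly 709, the largest argument exp() can take in double precision, whereas the max-shifted form stays finite. The resulting inf then shows up as nan as soon as an inf - inf occurs in the downstream corrections.

```python
import torch

a = torch.tensor([800.0], dtype=torch.double)  # an lse value beyond exp()'s double range
b = torch.tensor([790.0], dtype=torch.double)

# Naive merge: exp(800.) overflows to inf, so the merged lse becomes inf,
# and any later inf - inf (e.g. in an output correction factor) becomes nan.
naive = torch.log(a.exp() + b.exp())            # tensor([inf])

# Max-shifted merge: exp() only ever sees a non-positive argument.
m = torch.maximum(a, b)
stable = m + torch.log1p(torch.exp(torch.minimum(a, b) - m))  # ~tensor([800.0000])

print(naive, stable)
```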