
[Pytorch] Update context parallel softmax lse correction func #716

Merged: 1 commit into NVIDIA:main on Mar 21, 2024

Conversation

@Kite0011 (Contributor)

The original implementation would result in 'nan' when the value of lse.exp() exceeds the range of double, causing incorrect values and gradients at the corresponding positions.
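For context, the failure mode described above can be reproduced with a minimal sketch (illustrative code, not the TransformerEngine implementation; the function names are made up): merging two log-sum-exp values by exponentiating them directly overflows, while the standard max-factored form stays finite.

```python
import torch

def lse_merge_naive(lse_a, lse_b):
    # Exponentiating the raw lse values overflows to inf once lse exceeds
    # ~709 in double (~88 in float32); downstream ops such as inf - inf
    # then turn into nan in the corrected outputs and gradients.
    return torch.log(lse_a.exp() + lse_b.exp())

def lse_merge_stable(lse_a, lse_b):
    # Standard log-sum-exp trick: factor out the maximum so the exponent
    # is always <= 0 and can never overflow.
    max_lse = torch.maximum(lse_a, lse_b)
    min_lse = torch.minimum(lse_a, lse_b)
    return max_lse + torch.log1p(torch.exp(min_lse - max_lse))

lse_a = torch.tensor([800.0], dtype=torch.double)
lse_b = torch.tensor([795.0], dtype=torch.double)
print(lse_merge_naive(lse_a, lse_b))   # tensor([inf], dtype=torch.float64)
print(lse_merge_stable(lse_a, lse_b))  # tensor([800.0067], dtype=torch.float64)
```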

@ptrendx (Member) commented Mar 13, 2024

Hi @Kite0011, thank you for your contribution! Could you please sign your commits (as outlined here: https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)?

@xrennvidia (Collaborator)

Hi @Kite0011, really appreciate the fix. The math is correct; the problem is that the function should do an in-place update.

softmax_lse is initialized here as softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double). Every time you call flash_attn_fwd_softmax_lse_correction, you should update that initialized softmax_lse. Some calls update the whole softmax_lse and some only update half of it, but all calls should operate on the same tensor; otherwise the final softmax_lse is wrong.
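A simplified sketch of what an in-place version of this correction could look like (using the max-factored merge from the earlier sketch; the actual TransformerEngine implementation may differ in details such as tensor shapes):

```python
import torch

def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step):
    # Merge this step's softmax stats into the running stats. The key point
    # is the trailing copy_: callers that pass a half-slice (a view) of
    # softmax_lse still update the original tensor, instead of receiving a
    # new tensor that is silently dropped.
    max_scale = torch.max(softmax_lse, softmax_lse_per_step)
    min_scale = torch.min(softmax_lse, softmax_lse_per_step)
    new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
    softmax_lse.copy_(new_scale)

# Mimicking the initialization quoted above, then updating only half of the
# stats through a view -- the parent softmax_lse tensor is modified in place.
softmax_lse_per_step = [torch.randn(2, 8, 1024, dtype=torch.double) for _ in range(2)]
softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double)
flash_attn_fwd_softmax_lse_correction(softmax_lse[..., 512:],
                                      softmax_lse_per_step[1][..., 512:])
```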

@xrennvidia (Collaborator)

@Kite0011 out of curiosity, did you actually hit a case where the value falls outside the range of the double data type?

@Kite0011 (Contributor, Author)

> Hi @Kite0011, thank you for your contribution! Could you please sign your commits (as outlined here: https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)?

Thank you for your response; I will update my PR according to this link later.

@Kite0011 (Contributor, Author)

> @Kite0011 out of curiosity, did you actually hit a case where the value falls outside the range of the double data type?

Thank you for your reply. You are correct: the softmax_lse passed in at some steps is only a slice of the full tensor, so I will change this function to an in-place operation. (I understand that changing the final return to a copy should be enough?) As for the second question, we did hit a problem during a warm start where the forward results could not be aligned, and we eventually traced it to the attention lse exceeding the range of double after exp.

@xrennvidia (Collaborator) commented Mar 14, 2024

> @Kite0011 out of curiosity, did you actually hit a case where the value falls outside the range of the double data type?

> Thank you for your reply. You are correct: the softmax_lse passed in at some steps is only a slice of the full tensor, so I will change this function to an in-place operation. (I understand that changing the final return to a copy should be enough?) As for the second question, we did hit a problem during a warm start where the forward results could not be aligned, and we eventually traced it to the attention lse exceeding the range of double after exp.

Yeah, I think changing to softmax_lse.copy_(new_scale) should work.

Interesting that even double cannot cover the range. However, with your fix, I think even FP32 softmax_lse should now work. Anyway, that needs some further testing; you can leave it as double for now, and I will make a change if necessary.

Thanks for the fix, really appreciate it.
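(Aside: a rough numerical check of the FP32 point above, with made-up values. With the max-factored merge the exponent is never positive, so even float32 stays finite where the naive form already overflows.)

```python
import torch

lse_a = torch.tensor([100.0])  # float32; exp(100) overflows to inf
lse_b = torch.tensor([99.0])

naive = torch.log(lse_a.exp() + lse_b.exp())
stable = torch.maximum(lse_a, lse_b) + torch.log1p(
    torch.exp(torch.minimum(lse_a, lse_b) - torch.maximum(lse_a, lse_b))
)
print(naive)   # tensor([inf])
print(stable)  # tensor([100.3133])
```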

@Kite0011 (Contributor, Author)

> Hi @Kite0011, thank you for your contribution! Could you please sign your commits (as outlined here: https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)?

I've made the revision. Could you check whether there is anything else that needs to be done?

@Kite0011 (Contributor, Author) commented Mar 14, 2024

> @Kite0011 out of curiosity, did you actually hit a case where the value falls outside the range of the double data type?

> Thank you for your reply. You are correct: the softmax_lse passed in at some steps is only a slice of the full tensor, so I will change this function to an in-place operation. (I understand that changing the final return to a copy should be enough?) As for the second question, we did hit a problem during a warm start where the forward results could not be aligned, and we eventually traced it to the attention lse exceeding the range of double after exp.

> Yeah, I think changing to softmax_lse.copy_(new_scale) should work.

> Interesting that even double cannot cover the range. However, with your fix, I think even FP32 softmax_lse should now work. Anyway, that needs some further testing; you can leave it as double for now, and I will make a change if necessary.

> Thanks for the fix, really appreciate it.

I have made the revision; could you please review it again? Thank you for your hard work.

Also, it seems you're right: the lse-related operations should all be safe now. But since lse itself doesn't take up much memory or time, it's also fine to keep the data type as double.

@xrennvidia (Collaborator)

LGTM. Thanks!

@ptrendx (Member) left a comment

LGTM after signing the commits.

@Infi-zc commented Mar 19, 2024

By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia
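For reference, the pattern being asked about would look roughly like this (an illustrative sketch with hypothetical names, not TransformerEngine code):

```python
import torch

def accumulate_in_fp32(acc_bf16, update_bf16):
    # Upcast to fp32, accumulate, then cast back to bf16 and write the
    # result back in place, instead of adding the bf16 tensors directly.
    acc_fp32 = acc_bf16.to(torch.float32)
    acc_fp32 += update_bf16.to(torch.float32)
    acc_bf16.copy_(acc_fp32.to(torch.bfloat16))

out = torch.randn(4, 1024, 64, dtype=torch.bfloat16)
out_per_step = torch.randn_like(out)
accumulate_in_fp32(out, out_per_step)  # out now holds the fp32-accumulated sum
```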

@xrennvidia (Collaborator) commented Mar 19, 2024

> By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia

@Infi-zc I also considered this, but with BF16 results correction I did not see any loss curve issue. Did you encounter any case that needs FP32 accumulation?

@Infi-zc commented Mar 19, 2024

> By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia

> @Infi-zc I also considered this, but with BF16 results correction I did not see any loss curve issue. Did you encounter any case that needs FP32 accumulation?

Not yet; I'm not quite certain. I plan to conduct more experiments to verify this further.

@xrennvidia (Collaborator)

> By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia

> @Infi-zc I also considered this, but with BF16 results correction I did not see any loss curve issue. Did you encounter any case that needs FP32 accumulation?

> Not yet; I'm not quite certain. I plan to conduct more experiments to verify this further.

Yeah, sounds good, let me know if you encounter the issue. Thanks.

@Kite0011 (Contributor, Author) commented Mar 19, 2024

> By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia

In my case, I didn't encounter a loss difference from the bf16 forward o and backward dq, dkv; we have tried fp32 accumulation, but the final result didn't make much difference. @Infi-zc @xrennvidia

@ptrendx (Member) commented Mar 19, 2024

@Kite0011 I see some unrelated commits in this PR - I think this is some rebase issue. Could you resolve that?

@Kite0011 (Contributor, Author) commented Mar 20, 2024

> @Kite0011 I see some unrelated commits in this PR - I think this is some rebase issue. Could you resolve that?

Thank you for the reminder. It seems there was an issue with my sign-off and rebase. Done~ @ptrendx

@Kite0011 force-pushed the debug branch 2 times, most recently from 5f8df41 to b381929, on March 20, 2024 06:58.
@ptrendx (Member) commented Mar 20, 2024

/te-ci pytorch

@ptrendx merged commit 59bfc17 into NVIDIA:main on Mar 21, 2024 (18 of 20 checks passed).
@ptrendx (Member) commented Mar 21, 2024

Merged, thank you for the contribution @Kite0011 !

@Infi-zc commented Mar 21, 2024

> By the way, I would like to ask whether the accumulation of the forward out and the backward dq, dkv should also be done in fp32 first, and then converted back to fp16/bf16 before being copied back in place? Accumulating in fp32 first and then copying may introduce less numerical bias than accumulating directly in fp16/bf16. @xrennvidia

> In my case, I didn't encounter a loss difference from the bf16 forward o and backward dq, dkv; we have tried fp32 accumulation, but the final result didn't make much difference. @Infi-zc @xrennvidia

I've noticed a slight discrepancy, with differences on the order of 1e-3 to 1e-2 at certain steps, compared to the results without context parallelism (CP), or with CP where out and dqkv are accumulated in fp32. But the convergence trends are consistent. @xrennvidia @Kite0011

kunlunl added a commit to kunlunl/TransformerEngine that referenced this pull request Apr 2, 2024
ksivaman pushed a commit to vasunvidia/TransformerEngine that referenced this pull request on Apr 3, 2024: [Pytorch] Update context parallel softmax lse correction func (#716).
kunlunl added a commit to kunlunl/TransformerEngine that referenced this pull request Apr 22, 2024
kunlunl added a commit to kunlunl/TransformerEngine that referenced this pull request Apr 22, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request on May 15, 2024: [Pytorch] Update context parallel softmax lse correction func (#716).