
Add THD format support for Context Parallel #641

Merged (3 commits, May 13, 2024)

Conversation

@kunlunl (Contributor) commented on Jan 30, 2024:

Make Context Parallel support the THD format. Currently only the Flash Attention backend is supported, as Fused Attention does not yet support the THD format.
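For readers unfamiliar with the layout names: BSHD keeps a padded (batch, seq, heads, dim) tensor, while THD packs all sequences of a batch along a single "total tokens" dimension and uses flash-attn-style cumulative sequence lengths to mark the boundaries. A minimal sketch of the packing (all sizes here are hypothetical, not taken from this PR):

```python
from itertools import accumulate

# Hypothetical sizes for illustration only; not taken from the PR.
num_heads, head_dim = 2, 4
seq_lens = [3, 5, 2]   # three variable-length sequences in one batch

# BSHD would pad every sequence to max(seq_lens) and keep a batch dim;
# THD concatenates the sequences along one "total tokens" dimension.
total_tokens = sum(seq_lens)          # 10 tokens, no padding
thd_shape = (total_tokens, num_heads, head_dim)

# cu_seqlens records each sequence's start offset, flash-attn style,
# so attention kernels can recover the sequence boundaries.
cu_seqlens = [0] + list(accumulate(seq_lens))

print(thd_shape)   # (10, 2, 4)
print(cu_seqlens)  # [0, 3, 8, 10]
```

The absence of padding is what makes THD attractive for variable-length batches, but it also means every kernel must index through `cu_seqlens` rather than a fixed stride.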

@timmoon10 timmoon10 requested a review from cyanguwa February 8, 2024 20:39
@kunlunl force-pushed the add_thd_for_cp branch 2 times, most recently from 514b9b6 to afd7fe1 on March 14, 2024
@kunlunl (Contributor, Author) commented on Mar 14, 2024:

Added some custom CUDA kernels to replace the PyTorch native ops, so that THD matches BSHD performance when using context parallel.
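To see why THD needs custom kernels here: context parallel for causal attention commonly splits every sequence into 2 × cp_size chunks and assigns rank r the chunks r and (2 × cp_size − 1 − r) to balance work. With THD packing, gathering a rank's tokens means many small, per-sequence slices if done with native ops; fusing them into one indexed copy (and, in this PR, into dedicated CUDA kernels) avoids that overhead. A rough pure-Python sketch of the indexing, with illustrative parameters not taken from the PR:

```python
from itertools import accumulate

# Illustrative parameters, not from the PR's actual kernels.
cp_size, rank = 2, 0
seq_lens = [8, 4]                     # each divisible by 2 * cp_size
cu_seqlens = [0] + list(accumulate(seq_lens))
thd = list(range(cu_seqlens[-1]))     # stand-in for a packed (T, H, D) tensor

# For causal load balancing, each sequence is split into 2 * cp_size
# chunks; rank r keeps chunks r and (2 * cp_size - 1 - r), so every
# rank sees a comparable mix of early and late tokens.
def rank_token_indices(cu_seqlens, cp_size, rank):
    idx = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        chunk = (end - start) // (2 * cp_size)
        for c in (rank, 2 * cp_size - 1 - rank):
            idx.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return idx

# One gathered copy over all sequences is the access pattern worth fusing,
# instead of a Python loop of small per-sequence slice/copy ops.
local = [thd[i] for i in rank_token_indices(cu_seqlens, cp_size, rank)]
print(local)   # [0, 1, 6, 7, 8, 11]
```

In BSHD the same gather is a fixed-stride slice over the batch dimension, which native ops handle well; THD's per-sequence offsets are what make a fused kernel pay off.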

Resolved review threads on:
tests/pytorch/fused_attn/run_fused_attn_with_cp.py
transformer_engine/pytorch/csrc/extensions.h (3 threads)
@xrennvidia (Collaborator) commented:

Hi @kunlunl, it looks much better. I left some simple comments; I will let @cyanguwa review the CUDA code more carefully. Thanks.

@kunlunl force-pushed the add_thd_for_cp branch 2 times, most recently from dea2a00 to 43add29 on April 22, 2024
@xrennvidia (Collaborator) left a review:

LGTM. Thanks!

@cyanguwa (Collaborator) commented:

/te-ci pytorch

@cyanguwa (Collaborator) commented:

@kunlunl thanks for the PR. Could you fix the DCO and lint errors, please? Instructions are in the Details link. Thanks.

@cyanguwa (Collaborator) left a review:

LGTM. Pending context parallel CI.

@cyanguwa (Collaborator) commented on May 3, 2024:

@kunlunl could you also please update the pytest version to 7.2 in qa/L1_pytorch_context_parallel_test/test.sh? Thanks.

@kunlunl force-pushed the add_thd_for_cp branch 2 times, most recently from 044b028 to 80686b7 on May 6, 2024
@kunlunl (Contributor, Author) commented on May 6, 2024:

> thanks for the PR. Could you fix the DCO and lint errors please? Instructions are in the Details link. Thanks.

> could you also please update the pytest version to 7.2 in qa/L1_pytorch_context_parallel_test/test.sh? Thanks.

@cyanguwa Both are done.

@cyanguwa (Collaborator) commented:

/te-ci pytorch

@cyanguwa cyanguwa merged commit 476f659 into NVIDIA:main May 13, 2024
9 checks passed