
[JAX] Refactor fused attention #711

Merged 6 commits on Mar 22, 2024

Conversation

zlsh80826
Collaborator

@zlsh80826 zlsh80826 commented Mar 11, 2024

This PR makes a few refactorings to the JAX fused attention implementation:

  • The self/cross prefixes have been renamed to _qkvpacked/_kvpacked to more accurately reflect their meanings.
  • The three separate custom calls for _qkvpacked, _kvpacked, and separate have been consolidated into a single custom call. This simplification is intended to ease future development.
  • Added no_mask and padding_causal mask types to generalize mask support.
    • For masks with padding (i.e., padding_causal, padding), users must supply a runtime mask indicating the padding positions.
    • For masks without padding (i.e., no_mask, causal_mask), TE generates the masks automatically, avoiding the cost of computing padding offsets and thereby improving performance.
    • The documentation has been updated to reflect these changes in mask behavior.
  • Reduced the number of skipped cases in test_fused_attn.py.
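The mask taxonomy above can be sketched in plain jax.numpy. This is a hedged illustration, not the Transformer Engine API; `make_attention_mask` is a hypothetical helper used only to show how the padding and causal variants differ:

```python
# Hedged sketch of the mask families described above, in plain jax.numpy.
# NOT the Transformer Engine API; `make_attention_mask` is hypothetical.
import jax.numpy as jnp

def make_attention_mask(seqlens, max_seqlen, *, causal=False, padded=False):
    """Return a [batch, 1, seq, seq] boolean mask where True means "attend"."""
    pos = jnp.arange(max_seqlen)
    mask = jnp.ones((seqlens.shape[0], max_seqlen, max_seqlen), dtype=bool)
    if causal:
        # causal_mask / padding_causal: token i may only attend to j <= i
        mask = mask & (pos[None, :, None] >= pos[None, None, :])
    if padded:
        # padding / padding_causal: mask out positions beyond each sequence
        # length -- this is the part the user must supply at runtime
        valid = pos[None, :] < seqlens[:, None]            # [batch, seq]
        mask = mask & (valid[:, :, None] & valid[:, None, :])
    return mask[:, None, :, :]

# no_mask / causal_mask: no per-sequence lengths are needed, so TE can build
# the mask on the fly instead of computing padding offsets
causal = make_attention_mask(jnp.array([4, 4]), 4, causal=True)

# padding_causal: actual sequence lengths (2 and 3 here) must be provided
pad_causal = make_attention_mask(jnp.array([2, 3]), 4, causal=True, padded=True)
```

The performance point in the description follows from the `padded` branch: only the padded variants need per-sequence lengths at runtime, so the maskless variants can skip that work entirely.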

@zlsh80826
Collaborator Author

/te-ci jax

@zlsh80826 zlsh80826 force-pushed the rewang/refactor-fused-attn-impl-rebased branch from a42cd73 to efb770e on March 14, 2024 15:33
@zlsh80826
Collaborator Author

/te-ci jax

@zlsh80826 zlsh80826 force-pushed the rewang/refactor-fused-attn-impl-rebased branch from efb770e to 696d38e on March 15, 2024 03:55
@zlsh80826
Collaborator Author

/te-ci jax

@zlsh80826 zlsh80826 force-pushed the rewang/refactor-fused-attn-impl-rebased branch from 696d38e to 89ccd22 on March 15, 2024 14:07
@zlsh80826
Collaborator Author

/te-ci jax

@zlsh80826 zlsh80826 requested a review from denera March 18, 2024 03:11
@zlsh80826 zlsh80826 marked this pull request as ready for review March 18, 2024 03:11
Collaborator

@denera denera left a comment


LGTM!

@denera denera merged commit 8e672ff into NVIDIA:main Mar 22, 2024
15 checks passed
@nouiz nouiz added the jax label Mar 27, 2024
yhtang added a commit to NVIDIA/JAX-Toolbox that referenced this pull request Apr 8, 2024
- jax-ml/jax@0339928: jax-ml/jax#20588 applies @olupton's JAX fix related to a cuInit issue.
- Cherry-picked 2 TE commits that are related to JAX but added after the v1.5 release:
  - NVIDIA/TransformerEngine@8e672ff: NVIDIA/TransformerEngine#711
  - NVIDIA/TransformerEngine@bfe21c3: NVIDIA/TransformerEngine#744
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
* Remove unused headers

Signed-off-by: Reese Wang <[email protected]>

* Unify the fused attn workspace size cpp code

Signed-off-by: Reese Wang <[email protected]>

* Reduce the skipped cases

Signed-off-by: Reese Wang <[email protected]>

* Rename self/cross attention to qkvpacked/kvpacked

Signed-off-by: Reese Wang <[email protected]>

* Update attention mask docs

Signed-off-by: Reese Wang <[email protected]>

* Refine the attn mask implementations

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>