-
Notifications
You must be signed in to change notification settings - Fork 352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Refactor fused attention #711
Merged
denera
merged 6 commits into
NVIDIA:main
from
zlsh80826:rewang/refactor-fused-attn-impl-rebased
Mar 22, 2024
Merged
[JAX] Refactor fused attention #711
denera
merged 6 commits into
NVIDIA:main
from
zlsh80826:rewang/refactor-fused-attn-impl-rebased
Mar 22, 2024
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/te-ci jax |
zlsh80826
force-pushed
the
rewang/refactor-fused-attn-impl-rebased
branch
from
March 14, 2024 15:33
a42cd73
to
efb770e
Compare
/te-ci jax |
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
zlsh80826
force-pushed
the
rewang/refactor-fused-attn-impl-rebased
branch
from
March 15, 2024 03:55
efb770e
to
696d38e
Compare
/te-ci jax |
Signed-off-by: Reese Wang <[email protected]>
zlsh80826
force-pushed
the
rewang/refactor-fused-attn-impl-rebased
branch
from
March 15, 2024 14:07
696d38e
to
89ccd22
Compare
/te-ci jax |
denera
approved these changes
Mar 22, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This was referenced Apr 8, 2024
yhtang
added a commit
to NVIDIA/JAX-Toolbox
that referenced
this pull request
Apr 8, 2024
- jax-ml/jax@0339928: jax-ml/jax#20588 applies @olupton's JAX fix related to a cuInit issue. - Cherry-picked 2 TE commits that are related to JAX but added after the v1.5 release: - NVIDIA/TransformerEngine@8e672ff: NVIDIA/TransformerEngine#711 - NVIDIA/TransformerEngine@bfe21c3: NVIDIA/TransformerEngine#744
pggPL
pushed a commit
to pggPL/TransformerEngine
that referenced
this pull request
May 15, 2024
* Remove unused headers Signed-off-by: Reese Wang <[email protected]> * Unify the fused attn workspace size cpp code Signed-off-by: Reese Wang <[email protected]> * Reduce the skipped cases Signed-off-by: Reese Wang <[email protected]> * Rename self/cross attention to qkvpacked/kvpacked Signed-off-by: Reese Wang <[email protected]> * Update attention mask docs Signed-off-by: Reese Wang <[email protected]> * Refine the attn mask implementations Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Reese Wang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR does a few refactorings on JAX fused attention
self
/cross
have been renamed to_qkvpacked
/_kvpacked
to more accurately reflect their meanings.no_mask
andpadding_causal
for mask generalization.padding
(i.e.padding_causal
,padding
), users are required to supply a runtime mask indicating the positions of padding.padding
(i.e.no_mask
,causal_mask
), TE will generate masks automatically and can reduce the time needed to calculate padding offsets, thereby improving performance.test_fused_attn.py