Add support for SWA (left, right) with FusedAttention #2477
base: main
Conversation
…IA#1369 Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Greptile Summary
This PR adds support for sliding window attention (SWA) with configurable left and right window sizes plus diagonal alignment control to the FusedAttention backend. The implementation plumbs a new `bottom_right_diagonal` parameter through the Python and C++/cuDNN call stacks.
Key changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant TransformerLayer
participant MultiheadAttention
participant DotProductAttention
participant Backend as FusedAttention/FlashAttention
participant CUDA as CUDA C++/cuDNN
User->>TransformerLayer: forward(bottom_right_diagonal)
TransformerLayer->>TransformerLayer: Resolve bottom_right_diagonal<br/>based on mask type
TransformerLayer->>MultiheadAttention: forward(bottom_right_diagonal)
MultiheadAttention->>MultiheadAttention: Resolve bottom_right_diagonal<br/>based on mask type
MultiheadAttention->>DotProductAttention: forward(bottom_right_diagonal)
DotProductAttention->>DotProductAttention: Resolve bottom_right_diagonal<br/>based on mask type
DotProductAttention->>DotProductAttention: Update AttentionParams with<br/>bottom_right_diagonal
DotProductAttention->>DotProductAttention: Select backend based on<br/>window_size and diagonal alignment
alt FusedAttention Path
DotProductAttention->>Backend: FusedAttention.forward(bottom_right_diagonal)
Backend->>CUDA: nvte_fused_attn_fwd(bottom_right_diagonal)
CUDA->>CUDA: Apply SWA with (left, right)<br/>and diagonal alignment
CUDA-->>Backend: Output
else FlashAttention Path
DotProductAttention->>Backend: FlashAttention.forward(window_size)
Note over Backend: FlashAttention only supports<br/>bottom-right diagonal
Backend-->>DotProductAttention: Output
end
Backend-->>DotProductAttention: Attention output
DotProductAttention-->>MultiheadAttention: Return
MultiheadAttention-->>TransformerLayer: Return
TransformerLayer-->>User: Return
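For orientation, here is a minimal usage sketch at the DotProductAttention level, assuming the `window_size=(left, right)` and `bottom_right_diagonal` arguments this PR threads through the stack; the sizes, mask type, and dtype are illustrative only, and backend selection still happens internally:

```python
import torch
from transformer_engine.pytorch import DotProductAttention

seq_len, batch, heads, head_dim = 128, 2, 16, 64  # illustrative sizes

attn = DotProductAttention(
    num_attention_heads=heads,
    kv_channels=head_dim,
    attn_mask_type="causal_bottom_right",
)

# Default qkv_format is "sbhd": [seq, batch, heads, head_dim].
q = torch.randn(seq_len, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# window_size=(left, right): each query attends to at most `left` positions
# behind and `right` positions ahead of its (aligned) position.
# bottom_right_diagonal is the parameter added by this PR; whether the
# FusedAttention backend is actually selected depends on GPU/cuDNN support.
out = attn(q, k, v, window_size=(16, 8), bottom_right_diagonal=True)
```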
Additional Comments (2)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link) logic: Trailing comma creates a single-element tuple instead of a boolean. Should this be just `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]`?
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link) style: Uses a hardcoded mask type check instead of the new `bottom_right_diagonal` parameter for ALiBi alignment. Should this use the `bottom_right_diagonal` parameter for consistency? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
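As a standalone illustration of the trailing-comma pitfall flagged above (plain Python, not the TE code):

```python
attn_mask_type = "causal"

# The trailing comma turns the right-hand side into a one-element tuple,
# which is always truthy regardless of the membership test:
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],
print(type(bottom_right_alignment), bool(bottom_right_alignment))  # <class 'tuple'> True

# Without the trailing comma the result is the intended boolean:
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
print(type(bottom_right_alignment), bottom_right_alignment)  # <class 'bool'> False
```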
15 files reviewed, 8 comments
transformer_engine/pytorch/attention/dot_product_attention/backends.py (review thread outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (review thread outdated, resolved)
    if self_attn_mask_type in {"causal", "padding_causal"}:
        bottom_right_diagonal = False
    if bottom_right_diagonal is None or self_attn_mask_type in {
        "causal_bottom_right",
        "padding_causal_bottom_right",
    }:
        bottom_right_diagonal = True
logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?
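A possible restructuring along the lines the reviewer suggests, resolving the default only when the caller passed `None` (a hypothetical helper, not the code in this PR):

```python
def resolve_bottom_right_diagonal(bottom_right_diagonal, attn_mask_type):
    """Apply the mask-type-based default only when the caller passed None."""
    if bottom_right_diagonal is None:
        # Top-left alignment for plain causal masks, bottom-right otherwise.
        return attn_mask_type not in ("causal", "padding_causal")
    # An explicitly passed value is preserved.
    return bottom_right_diagonal


# e.g. in the layer's forward (names hypothetical):
# bottom_right_diagonal = resolve_bottom_right_diagonal(
#     bottom_right_diagonal, self_attn_mask_type
# )
```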
    if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
        enc_dec_bottom_right_diagonal = False
    if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
        "causal_bottom_right",
        "padding_causal_bottom_right",
    }:
        enc_dec_bottom_right_diagonal = True
logic: Same logic issue as above - mask type check overrides explicit parameter values
Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left,0) with (left, right) and test them out. Thanks!
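A rough sketch of what such CP test coverage could look like, using pytest parametrization; `run_attention` is a hypothetical stand-in for whatever harness the existing CP tests use:

```python
import pytest
import torch

# Window cases: the existing CP tests exercise (left, 0); nonzero right
# widths cover the new (left, right) support.
WINDOW_SIZES = [(-1, -1), (64, 0), (64, 32), (128, 16)]


def run_attention(window_size, attn_mask_type, context_parallel):
    """Placeholder for the CP test harness; the real tests would launch a
    (context-parallel) attention run here and return its output."""
    raise NotImplementedError


@pytest.mark.parametrize("window_size", WINDOW_SIZES)
@pytest.mark.parametrize("attn_mask_type", ["causal", "padding_causal_bottom_right"])
def test_cp_swa_left_right(window_size, attn_mask_type):
    out_cp = run_attention(window_size, attn_mask_type, context_parallel=True)
    out_ref = run_attention(window_size, attn_mask_type, context_parallel=False)
    torch.testing.assert_close(out_cp, out_ref, atol=1e-2, rtol=1e-2)
```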
    // NVTE fused attention FWD with packed QKV
    // DEPRECATED: This API is deprecated.
Should there be a version by which this is going to be removed? @cyang
I made some changes in #2272, but will see if I can make the 2.11 deadline.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (two review threads, outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/backends.py (review thread outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (two review threads, outdated, resolved)
Signed-off-by: Sudhakar Singh <[email protected]>
15 files reviewed, no comments
Additional Comments (4)
- transformer_engine/pytorch/transformer.py, line 777-783 (link) logic: Conditional logic overwrites an explicitly passed parameter. When `self_attn_mask_type in {"causal", "padding_causal"}` at line 777, this unconditionally sets `bottom_right_diagonal = False`, even if the user explicitly passed a different value in the forward call. The check at line 779 (`if bottom_right_diagonal is None`) becomes unreachable in those cases. Consider restructuring to only apply defaults when `bottom_right_diagonal is None`.
- transformer_engine/pytorch/transformer.py, line 787-793 (link) logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites `enc_dec_bottom_right_diagonal`, ignoring explicitly passed values.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link) logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link) logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
15 files reviewed, 4 comments
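To illustrate the two variable-name comments above, a simplified, hypothetical stand-in for one filter inside `get_attention_backend` (the real surrounding logic differs):

```python
# Simplified, hypothetical backend-selection flags.
use_flash_attention = True      # overall FlashAttention flag
use_flash_attention_2 = True    # version-specific flag
fa2_supports_feature = False    # illustrative capability check

if use_flash_attention_2 and not fa2_supports_feature:
    # Intended fix per the comments: clear the flag that was checked above.
    use_flash_attention_2 = False
    # The flagged code cleared `use_flash_attention` here instead, which in
    # this sketch would disable the whole FlashAttention family, not just FA2.
```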
/te-ci pytorch L0
1 similar comment
/te-ci pytorch L0
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link) syntax: Tuple has a trailing comma, causing it to be a single-element tuple instead of a boolean.
15 files reviewed, 1 comment
/te-ci pytorch L0
15 files reviewed, no comments
Signed-off-by: Sudhakar Singh <[email protected]>
/te-ci pytorch L0
Signed-off-by: Sudhakar Singh <[email protected]>
Force-pushed from a56c4d2 to a545ebf (compare)
16 files reviewed, 4 comments
/te-ci pytorch
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 963-964 (link) logic: Dead code: after removing `or max_seqlen_q != max_seqlen_kv` from line 959, `alibi_slopes_shape` can never be `None` inside this block.
16 files reviewed, 1 comment
…calculated in `fused_attn_fwd` call as well Signed-off-by: Sudhakar Singh <[email protected]>
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 960-963 (link) logic: This condition creates unreachable code. After removing `or max_seqlen_q != max_seqlen_kv` from line 956, the condition `alibi_slopes_shape is None` on line 960 can never be true, since line 956 already requires `alibi_slopes_shape is not None`.
26 files reviewed, 1 comment
Force-pushed from 8876202 to c42c555 (compare)
Signed-off-by: Sudhakar Singh <[email protected]>
Additional Comments (3)
- transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 1709 (link) logic: hardcoded `true` for `bottom_right_diagonal` in the FP8 forward path; should this be parameterized like the F16 path? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 2039 (link) logic: hardcoded `true` for `bottom_right_diagonal` in the FP8 backward path; should match the forward path and be parameterized.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 956 (link) style: removed the cross-attention check `max_seqlen_q != max_seqlen_kv` from the ALiBi filter; verify this works correctly with the new `bottom_right_diagonal` parameter for all scenarios.
16 files reviewed, 3 comments
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
Additional Comments (6)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link) logic: Variable name mismatch: setting `use_flash_attention` but should be `use_flash_attention_2` based on the context (line 898 checks `use_flash_attention_2`).
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link) logic: Variable name mismatch: setting `use_flash_attention` but should be `use_flash_attention_2` based on the context (line 927 checks `use_flash_attention_2`).
- transformer_engine/pytorch/transformer.py, line 777-778 (link) logic: Unconditionally overrides `bottom_right_diagonal` to `False` when the mask type is causal, even if the user explicitly passed a value in the forward call. The check should only apply when `bottom_right_diagonal` is `None`.
- transformer_engine/pytorch/transformer.py, line 787-788 (link) logic: Same issue: unconditionally overrides `enc_dec_bottom_right_diagonal` to `False` when the mask type is causal, even if the user explicitly passed a value. The check should only apply when `enc_dec_bottom_right_diagonal` is `None`.
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1100 (link) logic: Unconditionally overrides `bottom_right_diagonal` to `False` when the mask type is causal, even if the user explicitly passed a value. The check should only apply when `bottom_right_diagonal` is `None`.
- transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link) logic: Unconditionally overrides `bottom_right_diagonal` to `False` when the mask type is causal, even if the user explicitly passed a value. The check should only apply when `bottom_right_diagonal` is `None`.
17 files reviewed, 6 comments
Signed-off-by: Sudhakar Singh <[email protected]>
…ansformerEngine into swa_padding_brcm_try2
/te-ci pytorch L1
for more information, see https://pre-commit.ci
Additional Comments (6)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link) logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link) logic: Variable name mismatch: setting `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/transformer.py, line 777-783 (link) logic: Logic overrides the user's explicit parameter. If the user passes `bottom_right_diagonal=True` and the mask type is `causal`, line 778 forces it to `False`, then line 783 forces it back to `True` if the mask type contains "bottom_right". Defaults should only apply when the parameter is `None`.
- transformer_engine/pytorch/transformer.py, line 787-793 (link) logic: Same logic issue: overrides the user's explicit `enc_dec_bottom_right_diagonal` parameter.
- transformer_engine/pytorch/attention/multi_head_attention.py, line 751-757 (link) logic: Same logic issue as transformer.py: overrides the user's explicit `bottom_right_diagonal` parameter.
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1105 (link) logic: Same logic issue: overrides the user's explicit `bottom_right_diagonal` parameter.
17 files reviewed, 6 comments
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
Additional Comments (6)
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link) syntax: Variable name mismatch: should be `use_flash_attention_2`, not `use_flash_attention`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link) syntax: Same variable name mismatch: should be `use_flash_attention_2`, not `use_flash_attention`.
- transformer_engine/pytorch/transformer.py, line 777-778 (link) logic: Logic unconditionally overrides the user-provided `bottom_right_diagonal` parameter. If the user explicitly passes `bottom_right_diagonal=True` with `self_attn_mask_type="causal"`, it gets overridden to `False`. Should the mask type check only apply when `bottom_right_diagonal` is `None`, rather than override explicit parameter values?
- transformer_engine/pytorch/transformer.py, line 787-788 (link) logic: Same issue: the logic overrides explicit `enc_dec_bottom_right_diagonal` parameter values based on mask type. Should this only apply when the parameter is `None`?
- transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link) logic: Same logic issue as in transformer.py: the mask type check overrides explicit `bottom_right_diagonal` values. The mask type check should only apply when the parameter is `None`.
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1098-1099 (link) logic: Consistent with the other files: the mask type check overrides explicit `bottom_right_diagonal` values instead of only applying when `None`. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
17 files reviewed, 6 comments
/te-ci pytorch L1
Description
FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
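To make the (left, right) window and diagonal-alignment semantics concrete, a small self-contained sketch of how such a mask can be built; this illustrates the masking convention only, not TE's implementation:

```python
import torch

def swa_mask(seqlen_q, seqlen_kv, left, right, bottom_right_diagonal=True):
    """Boolean mask (True = attend) for sliding window attention with a
    (left, right) window; -1 means unbounded on that side."""
    q_pos = torch.arange(seqlen_q).unsqueeze(1)    # [seqlen_q, 1]
    kv_pos = torch.arange(seqlen_kv).unsqueeze(0)  # [1, seqlen_kv]
    # With bottom-right alignment the diagonal is shifted so the last query
    # lines up with the last key; with top-left alignment there is no shift.
    shift = (seqlen_kv - seqlen_q) if bottom_right_diagonal else 0
    offset = kv_pos - (q_pos + shift)  # 0 on the chosen diagonal
    mask = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool)
    if left >= 0:
        mask &= offset >= -left
    if right >= 0:
        mask &= offset <= right
    return mask

# Example: 4 queries over 6 keys, window (2, 1), bottom-right aligned.
print(swa_mask(4, 6, left=2, right=1).int())
```

A causal window is the special case (left, 0); (-1, -1) disables windowing entirely.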
Type of change
Changes
Please list the changes introduced in this PR:
- transformer_engine
  - common
    - fused_attn
      - fused_attn.cpp: add `bottom_right_diagonal` parameter to the API
      - fused_attn_f16_arbitrary_seqlen.cu: add `bottom_right_diagonal` parameter to the API
      - fused_attn_fp8.cu: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
      - utils.h: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
  - pytorch
    - transformer.py: pass `bottom_right_diagonal` through the call stack: `TransformerLayer` --> `SelfAttention`/`CrossAttention`
    - attention
      - dot_product_attention
        - backends.py:
          - `UnfusedDotProductAttention`: add `bottom_right_diagonal` parameter to the `forward` API (`forward`? `bottom_right_alignment` is being used in the ALiBi call, perhaps this should be corrected)
          - `FusedAttn` custom module: add `bottom_right_diagonal` parameter to the `forward` API
          - `FusedAttention` module: pass `bottom_right_diagonal` through the call stack
        - dot_product_attention.py: `DotProductAttention`: pass `bottom_right_diagonal` through the call stack; resolve `bottom_right_diagonal` if it's `None`
        - utils.py: `AttentionParams`, `get_attention_backend`
      - multi_head_attention.py: add `bottom_right_diagonal` to the forward API and call; resolve `bottom_right_diagonal` if it's `None`
    - cpp_extensions
      - fused_attn.py: `bottom_right_diagonal` in `fused_attn_fwd`/`fused_attn_bwd`
    - csrc
      - extensions
        - attention.cpp: pass `bottom_right_diagonal` through the call stack: `fused_attn_fwd` --> `nvte_fused_attn_fwd`
      - extensions.h: add `bottom_right_diagonal` to the `fused_attn_fwd` and `fused_attn_bwd` API definitions

Checklist: