Conversation

@sudhakarsingh27
Collaborator

Description

FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
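
For orientation, a minimal sketch (plain PyTorch, not TE code) of the masking semantics this enables, assuming the usual convention that a window value of -1 means "unbounded" on that side:

import torch

def swa_mask(seqlen_q, seqlen_kv, window_size=(-1, -1)):
    # Query i may attend to key j with i - left <= j <= i + right.
    left, right = window_size
    q = torch.arange(seqlen_q).unsqueeze(1)   # [seqlen_q, 1]
    k = torch.arange(seqlen_kv).unsqueeze(0)  # [1, seqlen_kv]
    allowed = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool)
    if left >= 0:
        allowed &= k >= q - left
    if right >= 0:
        allowed &= k <= q + right
    return allowed  # True where attention is permitted

print(swa_mask(6, 6, (2, 0)).int())  # causal-style window (left, 0)
print(swa_mask(6, 6, (2, 1)).int())  # general (left, right) window added here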

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

transformer_engine

  • common

    • fused_attn
      • fused_attn.cpp
        • add bottom_right_diagonal parameter to the API
        • edit the filters to allow sliding window configs to pick the arbitrary-seqlen fused attention backend
      • fused_attn_f16_arbitrary_seqlen.cu: add bottom_right_diagonal parameter to the API
      • fused_attn_fp8.cu: add bottom_right_diagonal parameter to the FADescriptor_v1 API
      • utils.h: add bottom_right_diagonal parameter to FADescriptor_v1 API
  • pytorch

    • transformer.py
      • plumb bottom_right_diagonal through the call stack: TransformerLayer --> SelfAttention/CrossAttention
    • attention
      • dot_product_attention
        • backends.py:
          • UnfusedDotProductAttention
            • add bottom_right_diagonal parameter to the forward API
              • why is it not used in the forward?
                • bottom_right_alignment is used in the ALiBi call instead; perhaps this should be corrected
          • FusedAttn custom module
            • add bottom_right_diagonal parameter to the forward API
          • FusedAttention module
            • plumb bottom_right_diagonal through the call stack
        • dot_product_attention.py
          • DotProductAttention
            • Plumb bottom_right_diagonal through the call stack
            • Add calculation of bottom_right_diagonal if it's None
        • utils.py
          • AttentionParams
            • [x]
          • get_attention_backend
            • update sliding window filter section
            • update attention bias filter section
      • multi_head_attention.py
        • Add bottom_right_diagonal to forward API and call
        • Add calculation of bottom_right_diagonal if it's None
    • cpp_extensions
      • fused_attn.py
        • plumb bottom_right_diagonal in fused_attn_fwd/fused_attn_bwd
    • csrc
      • extensions
        • attention.cpp
          • plumb bottom_right_diagonal through the call stack: fused_attn_fwd --> nvte_fused_attn_fwd
          • same as above for bwd
      • extensions.h
        • add bottom_right_diagonal to fused_attn_fwd and fused_attn_bwd API definitions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Summary

This PR adds support for sliding window attention (SWA) with configurable left and right window sizes plus diagonal alignment control to the FusedAttention backend. The implementation plumbs a new bottom_right_diagonal parameter through the entire attention stack from TransformerLayer down to the C++/CUDA kernels.
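
For context, a small illustration (an assumption-laden sketch in plain PyTorch, not the PR's code) of what the diagonal alignment controls when seqlen_q != seqlen_kv: with top-left alignment the causal diagonal starts at position (0, 0), while with bottom-right alignment it is shifted by seqlen_kv - seqlen_q so it ends in the bottom-right corner of the softmax matrix.

import torch

def causal_mask(seqlen_q, seqlen_kv, bottom_right_diagonal=True):
    q = torch.arange(seqlen_q).unsqueeze(1)
    k = torch.arange(seqlen_kv).unsqueeze(0)
    # Shift the diagonal to the bottom-right corner for rectangular matrices.
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    return k <= q + offset  # True where attention is permitted

print(causal_mask(3, 6, bottom_right_diagonal=False).int())
print(causal_mask(3, 6, bottom_right_diagonal=True).int())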

Key changes:

  • Added bottom_right_diagonal parameter to control whether sliding window and ALiBi diagonals align to top-left (False) or bottom-right (True) of the softmax matrix
  • Updated C++ backend selection logic to allow arbitrary sequence length fused attention with sliding window configurations (left, right)
  • Extended tests to cover more layout configurations

Issues found:

  • Critical: Variable name mismatch in utils.py lines 911 and 938 - sets use_flash_attention instead of use_flash_attention_2, causing incorrect backend selection
  • Logic issue: Multiple locations override explicit bottom_right_diagonal parameter values based on mask type, instead of only applying defaults when parameter is None. This prevents users from explicitly controlling diagonal alignment
  • Previous thread comments about TODO in backends.py and commented-out code remain unaddressed

Confidence Score: 2/5

  • Not safe to merge - contains critical variable name bug that will cause incorrect backend selection
  • The variable name mismatch (use_flash_attention vs use_flash_attention_2) on lines 911 and 938 will cause the wrong backend flag to be disabled, leading to incorrect attention backend selection. Additionally, the logic that unconditionally overrides explicit parameter values needs review
  • Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/utils.py for the critical variable name bug, and review the parameter override logic in transformer.py, multi_head_attention.py, and dot_product_attention.py

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Added bottom_right_diagonal parameter to AttentionParams and backend selection logic; contains variable name mismatch bug on line 911
  • transformer_engine/pytorch/transformer.py: Added bottom_right_diagonal parameter throughout call stack with logic that overrides explicit parameter values
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: Added bottom_right_diagonal parameter and logic to DotProductAttention class with proper ALiBi cache handling
  • transformer_engine/pytorch/attention/multi_head_attention.py: Added bottom_right_diagonal parameter to MultiheadAttention with proper parameter forwarding

Sequence Diagram

sequenceDiagram
    participant User
    participant TransformerLayer
    participant MultiheadAttention
    participant DotProductAttention
    participant Backend as FusedAttention/FlashAttention
    participant CUDA as CUDA C++/cuDNN

    User->>TransformerLayer: forward(bottom_right_diagonal)
    TransformerLayer->>TransformerLayer: Resolve bottom_right_diagonal<br/>based on mask type
    TransformerLayer->>MultiheadAttention: forward(bottom_right_diagonal)
    MultiheadAttention->>MultiheadAttention: Resolve bottom_right_diagonal<br/>based on mask type
    MultiheadAttention->>DotProductAttention: forward(bottom_right_diagonal)
    DotProductAttention->>DotProductAttention: Resolve bottom_right_diagonal<br/>based on mask type
    DotProductAttention->>DotProductAttention: Update AttentionParams with<br/>bottom_right_diagonal
    DotProductAttention->>DotProductAttention: Select backend based on<br/>window_size and diagonal alignment
    alt FusedAttention Path
        DotProductAttention->>Backend: FusedAttention.forward(bottom_right_diagonal)
        Backend->>CUDA: nvte_fused_attn_fwd(bottom_right_diagonal)
        CUDA->>CUDA: Apply SWA with (left, right)<br/>and diagonal alignment
        CUDA-->>Backend: Output
    else FlashAttention Path
        DotProductAttention->>Backend: FlashAttention.forward(window_size)
        Note over Backend: FlashAttention only supports<br/>bottom-right diagonal
        Backend-->>DotProductAttention: Output
    end
    Backend-->>DotProductAttention: Attention output
    DotProductAttention-->>MultiheadAttention: Return
    MultiheadAttention-->>TransformerLayer: Return
    TransformerLayer-->>User: Return

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)

    logic: Trailing comma creates single-element tuple instead of boolean - should this be just bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]?
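
    A quick standalone demonstration of the flagged pitfall (illustrative only, not the PR's code): the trailing comma turns the right-hand side into a one-element tuple, which is always truthy.

    attn_mask_type = "causal"
    with_comma = attn_mask_type not in ["causal", "padding_causal"],   # tuple: (False,)
    without_comma = attn_mask_type not in ["causal", "padding_causal"]  # bool: False
    print(type(with_comma), bool(with_comma))        # <class 'tuple'> True
    print(type(without_comma), bool(without_comma))  # <class 'bool'> False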

  2. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)

    style: Uses hardcoded mask type check instead of the new bottom_right_diagonal parameter for ALiBi alignment. Should this use bottom_right_diagonal parameter for consistency?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

15 files reviewed, 8 comments

Comment on lines +777 to +783
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True

logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?

Comment on lines +787 to +793
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True

logic: Same logic issue as above - mask type check overrides explicit parameter values

Collaborator

@cyanguwa cyanguwa Dec 5, 2025

Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left, 0) with (left, right) and testing them out. Thanks!


// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// DEPRECATED: This API is deprecated. (Should there be a version by which this is going to be removed? @cyang)
Collaborator

I made some changes in #2272, but will see if I can make the 2.11 deadline.

@greptile-apps greptile-apps bot left a comment

15 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Conditional logic overwrites explicitly passed parameter. When self_attn_mask_type in {"causal", "padding_causal"} at line 777, this unconditionally sets bottom_right_diagonal = False, even if the user explicitly passed a different value in the forward call. The check at line 779 (if bottom_right_diagonal is None) becomes unreachable in those cases.

    Consider restructuring to only apply defaults when bottom_right_diagonal is None:
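
    A possible shape for that restructuring (a sketch only, keeping the current defaults; not the PR's actual code):

    if bottom_right_diagonal is None:
        # Derive the default from the mask type only when no explicit value was passed.
        bottom_right_diagonal = self_attn_mask_type not in {"causal", "padding_causal"}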

  2. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites enc_dec_bottom_right_diagonal, ignoring explicitly passed values.

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  4. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

15 files reviewed, 4 comments

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

1 similar comment
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)

    syntax: Tuple has trailing comma causing it to be a single-element tuple instead of boolean

15 files reviewed, 1 comment

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps greptile-apps bot left a comment

15 files reviewed, no comments

Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps greptile-apps bot left a comment

16 files reviewed, 4 comments

@sudhakarsingh27
Collaborator Author

/te-ci pytorch

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 963-964 (link)

    logic: Dead code: after removing or max_seqlen_q != max_seqlen_kv from line 959, alibi_slopes_shape can never be None inside this block

16 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 960-963 (link)

    logic: This condition creates unreachable code. After removing or max_seqlen_q != max_seqlen_kv from line 956, the condition alibi_slopes_shape is None on line 960 can never be true since line 956 already requires alibi_slopes_shape is not None

26 files reviewed, 1 comment

Signed-off-by: Sudhakar Singh <[email protected]>

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 1709 (link)

    logic: hardcoded true for bottom_right_diagonal in FP8 forward path - should this be parameterized like F16 path?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  2. transformer_engine/common/fused_attn/fused_attn_fp8.cu, line 2039 (link)

    logic: hardcoded true for bottom_right_diagonal in FP8 backward path - should match forward and be parameterized

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 956 (link)

    style: removed cross-attention check max_seqlen_q != max_seqlen_kv from alibi filter - verify this works correctly with new bottom_right_diagonal parameter for all scenarios

16 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link)

    logic: Variable name mismatch - setting use_flash_attention but should be use_flash_attention_2 based on the context (line 898 checks use_flash_attention_2)

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link)

    logic: Variable name mismatch - setting use_flash_attention but should be use_flash_attention_2 based on the context (line 927 checks use_flash_attention_2)

  3. transformer_engine/pytorch/transformer.py, line 777-778 (link)

    logic: Unconditionally overrides bottom_right_diagonal to False when mask type is causal, even if user explicitly passed a value in the forward call. Check should only apply when bottom_right_diagonal is None.

  4. transformer_engine/pytorch/transformer.py, line 787-788 (link)

    logic: Same issue - unconditionally overrides enc_dec_bottom_right_diagonal to False when mask type is causal, even if user explicitly passed a value. Check should only apply when enc_dec_bottom_right_diagonal is None.

  5. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1100 (link)

    logic: Unconditionally overrides bottom_right_diagonal to False when mask type is causal, even if user explicitly passed a value. The check should only apply when bottom_right_diagonal is None.

  6. transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link)

    logic: Unconditionally overrides bottom_right_diagonal to False when mask type is causal, even if user explicitly passed a value. The check should only apply when bottom_right_diagonal is None.

17 files reviewed, 6 comments

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  3. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Logic overrides user's explicit parameter. If user passes bottom_right_diagonal=True and mask type is causal, line 778 forces it to False, then line 783 forces it back to True if mask type contains "bottom_right". Should only apply defaults when parameter is None.

  4. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same logic issue: overrides user's explicit enc_dec_bottom_right_diagonal parameter

  5. transformer_engine/pytorch/attention/multi_head_attention.py, line 751-757 (link)

    logic: Same logic issue as transformer.py: overrides user's explicit bottom_right_diagonal parameter

  6. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1105 (link)

    logic: Same logic issue: overrides user's explicit bottom_right_diagonal parameter

17 files reviewed, 6 comments

@greptile-apps greptile-apps bot left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link)

    syntax: Variable name mismatch - should be use_flash_attention_2 not use_flash_attention

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link)

    syntax: Same variable name mismatch - should be use_flash_attention_2 not use_flash_attention

  3. transformer_engine/pytorch/transformer.py, line 777-778 (link)

    logic: Logic unconditionally overrides user-provided bottom_right_diagonal parameter. If user explicitly passes bottom_right_diagonal=True with self_attn_mask_type="causal", it gets overridden to False. Should the mask type check only apply when bottom_right_diagonal is None, not override explicit parameter values?

  4. transformer_engine/pytorch/transformer.py, line 787-788 (link)

    logic: Same issue - logic overrides explicit enc_dec_bottom_right_diagonal parameter values based on mask type. Should this only apply when the parameter is None?

  5. transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link)

    logic: Same logic issue as in transformer.py - mask type check overrides explicit bottom_right_diagonal values. The pattern of checking mask type should only apply when the parameter is None

  6. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1098-1099 (link)

    logic: Consistent with other files - mask type check overrides explicit bottom_right_diagonal values instead of only applying when None

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

17 files reviewed, 6 comments

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1
