Skip to content

Conversation

@KshitijLakhani
Copy link
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 16, 2025

Description

SequenceDescriptor's from_segment_ids_and_pos() accepts the segment_ids and an optional segment_pos as input. This class is supposed to serve as a convenience method to do two things:

  1. Stuff the segment_ids and segment_pos in a SequenceDescriptor object for TE to use downstream
  2. If the segment_pos is not passed, then calculate/extrapolate it

The second functionality gives incorrect results for THD cases and Load Balanced cases as it merely uses an arange to calculate the segment_pos naively. This could result in an incorrect mask for these cases.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR makes two changes:

  1. Passed two new args to this function from_segment_ids_and_pos() : is_thd and is_load_balanced. The defaults are set to False - the only case that this function can currently guarantee to support is BSHD with no load balancing
  2. If THD or load balanced use cases call this function from_segment_ids_and_pos(), it will assert
  3. The fused attn tests were modified to account for these two new args

Impact on user of the API:

  1. No breaking API changes
  2. If the user was (incorrectly) using this convenience function by only passing the segment_ids for THD or Load balanced cases, it will now assert and the user will have to explicitly pass the segment_pos
  3. If the user was (correctly) using this convenience function by only passing the segment_ids for a BSHD + No load balanced cases, it will behave as before and no changes are needed

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and lod balanced cases

Signed-off-by: Kshitij Lakhani <[email protected]>
@KshitijLakhani KshitijLakhani self-assigned this Dec 16, 2025
@KshitijLakhani
Copy link
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani added attention jax bug Something isn't working labels Dec 16, 2025
@KshitijLakhani
Copy link
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 17, 2025 02:16
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 17, 2025

Greptile Overview

Greptile Summary

Fixed incorrect automatic calculation of segment_pos from segment_ids in SequenceDescriptor.from_segment_ids_and_pos() for THD and load-balanced cases.

  • Added two new parameters is_thd and is_load_balanced to explicitly indicate layout type and load balancing state
  • Added assertions to prevent automatic segment_pos calculation (via arange) for THD layouts and load-balanced inputs where this would produce incorrect masks
  • Updated all test calls to pass the new parameters based on qkv_layout.is_thd() and cp_size > 1 and cp_load_balanced
  • Non-breaking change: defaults are False, and users must now explicitly pass segment_pos for THD/load-balanced cases
  • Added warning when segment_pos=None is used to clarify it's only valid for BSHD without load balancing

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - it fixes a correctness bug and adds proper validation
  • Score reflects a well-targeted bug fix with appropriate guardrails. The change correctly prevents incorrect mask calculation for THD and load-balanced cases by requiring explicit segment_pos. Tests are updated consistently. Minor style suggestion on f-string usage in assertions, but doesn't affect functionality.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 4/5 Added two parameters (is_thd, is_load_balanced) to from_segment_ids_and_pos() with assertions to prevent incorrect segment_pos calculation for THD and load-balanced cases
tests/jax/test_fused_attn.py 5/5 Updated test calls to from_segment_ids_and_pos() to pass the new is_thd and is_load_balanced parameters based on test configuration

Sequence Diagram

sequenceDiagram
    participant Test as Test Code
    participant SD as SequenceDescriptor
    participant Validation as Validation Logic
    
    alt THD Layout (is_thd=True)
        Test->>Test: Generate segment_ids and segment_pos
        Test->>SD: from_segment_ids_and_pos(segment_ids, segment_pos, is_thd=True, is_load_balanced)
        SD->>SD: Expand segment_ids to pair
        SD->>SD: Expand segment_pos to pair
        SD->>SD: Create SequenceDescriptor
    else BSHD Layout with segment_pos
        Test->>Test: Generate segment_ids and segment_pos
        Test->>SD: from_segment_ids_and_pos(segment_ids, segment_pos, is_thd=False, is_load_balanced)
        SD->>SD: Expand segment_ids to pair
        SD->>SD: Expand segment_pos to pair
        SD->>SD: Create SequenceDescriptor
    else BSHD Layout without segment_pos (None)
        Test->>Test: Generate only segment_ids
        Test->>SD: from_segment_ids_and_pos(segment_ids, None, is_thd=False, is_load_balanced=False)
        SD->>SD: Expand segment_ids to pair
        SD->>Validation: Check if is_load_balanced
        alt is_load_balanced=True
            Validation-->>SD: AssertionError!
        else is_thd=True
            Validation-->>SD: AssertionError!
        else Valid BSHD case
            Validation->>SD: Issue warning
            SD->>SD: Generate default segment_pos using arange
            SD->>SD: Create SequenceDescriptor
        end
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 841-842 (link)

    logic: q_seg_ids and kv_seg_ids are used here but not defined until line 847. This will cause a NameError at runtime when segment_pos is None.

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@KshitijLakhani
Copy link
Collaborator Author

/te-ci jax L0 L1

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 829-832 (link)

    style: f-strings in assert messages won't be evaluated until assertion fails

2 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

attention bug Something isn't working jax

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant