[JAX] Fix incorrect calculation of segment pos from segment ids #2523
base: main
…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax L0 L1
Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax L0 L1
Greptile Overview
Greptile Summary
Fixed incorrect automatic calculation of
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
    participant Test as Test Code
    participant SD as SequenceDescriptor
    participant Validation as Validation Logic
    alt THD Layout (is_thd=True)
        Test->>Test: Generate segment_ids and segment_pos
        Test->>SD: from_segment_ids_and_pos(segment_ids, segment_pos, is_thd=True, is_load_balanced)
        SD->>SD: Expand segment_ids to pair
        SD->>SD: Expand segment_pos to pair
        SD->>SD: Create SequenceDescriptor
    else BSHD Layout with segment_pos
        Test->>Test: Generate segment_ids and segment_pos
        Test->>SD: from_segment_ids_and_pos(segment_ids, segment_pos, is_thd=False, is_load_balanced)
        SD->>SD: Expand segment_ids to pair
        SD->>SD: Expand segment_pos to pair
        SD->>SD: Create SequenceDescriptor
    else BSHD Layout without segment_pos (None)
        Test->>Test: Generate only segment_ids
        Test->>SD: from_segment_ids_and_pos(segment_ids, None, is_thd=False, is_load_balanced=False)
        SD->>SD: Expand segment_ids to pair
        SD->>Validation: Check if is_load_balanced
        alt is_load_balanced=True
            Validation-->>SD: AssertionError!
        else is_thd=True
            Validation-->>SD: AssertionError!
        else Valid BSHD case
            Validation->>SD: Issue warning
            SD->>SD: Generate default segment_pos using arange
            SD->>SD: Create SequenceDescriptor
        end
    end
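The flow in the diagram, and the reason the THD and load balanced branches assert, can be sketched in plain Python. This is only an illustration under stated assumptions: `resolve_segment_pos` and `per_segment_pos` are hypothetical helpers, not Transformer Engine functions; they work on a single row given as a Python list rather than on JAX arrays, and the message strings are invented for the example.

```python
import warnings


def resolve_segment_pos(segment_ids, segment_pos=None, is_thd=False, is_load_balanced=False):
    """Hypothetical stand-in for the validation step shown in the diagram."""
    if segment_pos is not None:
        # Caller supplied positions explicitly: use them as given.
        return list(segment_pos)
    # No positions supplied: only BSHD without load balancing may fall back to arange.
    assert not is_load_balanced, "segment_pos must be passed for load balanced inputs"
    assert not is_thd, "segment_pos must be passed for THD inputs"
    warnings.warn("segment_pos not passed; falling back to arange over the sequence")
    return list(range(len(segment_ids)))


def per_segment_pos(segment_ids):
    """Position of each token inside its own segment (restarts when the id changes)."""
    positions, previous, count = [], None, 0
    for seg_id in segment_ids:
        count = count + 1 if seg_id == previous else 0
        positions.append(count)
        previous = seg_id
    return positions


# One packed row holding three segments of lengths 3, 2 and 3.
packed_ids = [1, 1, 1, 2, 2, 3, 3, 3]
naive = list(range(len(packed_ids)))    # what a plain arange yields
correct = per_segment_pos(packed_ids)   # positions restart at every segment boundary
```

For the packed row, `naive` keeps counting across segment boundaries while `correct` restarts at each one, which is the mismatch that motivates requiring an explicit `segment_pos` in the THD case.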
Additional Comments (1)
transformer_engine/jax/attention.py, lines 841-842
logic: `q_seg_ids` and `kv_seg_ids` are used here but not defined until line 847. This will cause a `NameError` at runtime when `segment_pos` is `None`.
2 files reviewed, 1 comment
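The use-before-definition pattern flagged in the review comment above can be reproduced in isolation. The sketch below is an assumption-laden reduction, not the code from `attention.py`: the function names are hypothetical and plain lists stand in for arrays. In Python the error raised for a local variable read before assignment is `UnboundLocalError`, which is a subclass of `NameError`.

```python
def build_pos_buggy(segment_ids, segment_pos=None):
    # Hypothetical reduction of the flagged pattern.
    if segment_pos is None:
        # q_seg_ids is read here, before the assignment below has run.
        segment_pos = list(range(len(q_seg_ids)))
    q_seg_ids = segment_ids
    return q_seg_ids, segment_pos


def build_pos_fixed(segment_ids, segment_pos=None):
    # Define the name first, then use it in the fallback branch.
    q_seg_ids = segment_ids
    if segment_pos is None:
        segment_pos = list(range(len(q_seg_ids)))
    return q_seg_ids, segment_pos
```

Note that the buggy version only fails when `segment_pos` is `None`, so tests that always pass `segment_pos` would not catch it.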
Signed-off-by: Kshitij Lakhani <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax L0 L1
Additional Comments (1)
transformer_engine/jax/attention.py, lines 829-832
style: f-strings in assert messages won't be evaluated until the assertion fails
2 files reviewed, 1 comment
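The behaviour mentioned in the style comment above, that an assert message is only built when the assertion fails, can be checked with a small experiment. The helper names and the message text here are hypothetical and chosen only for the demonstration.

```python
evaluations = []


def describe(seq_len):
    # Hypothetical message builder; records every time a message is actually built.
    evaluations.append(seq_len)
    return f"segment_pos is required, got sequence length {seq_len}"


def check(condition, seq_len):
    # The expression after the comma is evaluated only when `condition` is falsy.
    assert condition, describe(seq_len)


check(True, 8)                      # passes; describe() is never called
calls_after_pass = len(evaluations)

try:
    check(False, 8)                 # fails; describe() runs once to build the message
except AssertionError as err:
    failure_message = str(err)

calls_after_fail = len(evaluations)
```

One consequence is that an error inside the message expression itself stays hidden until the first time the assertion actually fails.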
Description
SequenceDescriptor's `from_segment_ids_and_pos()` accepts `segment_ids` and an optional `segment_pos` as input. This method is meant to serve as a convenience that does two things:
- Wrap `segment_ids` and `segment_pos` in a SequenceDescriptor object for TE to use downstream
- If `segment_pos` is not passed, calculate/extrapolate it
The second functionality gives incorrect results for THD cases and load balanced cases, because it naively uses an `arange` to calculate `segment_pos`. This could result in an incorrect mask for these cases.
Type of change
Changes
This PR makes two changes:
- Adds two arguments to `from_segment_ids_and_pos()`: `is_thd` and `is_load_balanced`. The defaults are set to `False`, since the only case this function can currently guarantee to support is BSHD with no load balancing
- If `segment_pos` is not passed to `from_segment_ids_and_pos()` for a THD or load balanced case, it will assert
Impact on user of the API:
- If the user passes only `segment_ids` for THD or load balanced cases, it will now assert and the user will have to pass `segment_pos` explicitly
- If the user passes only `segment_ids` for BSHD cases without load balancing, it will behave as before and no changes are needed
Checklist: