
[JAX] Fused attention unit tests fixes and refinements #1352

Merged: 10 commits into NVIDIA:main, Dec 17, 2024

Conversation

zlsh80826 (Collaborator) commented on Dec 2, 2024:

Description

Refinements to the JAX test_fused_attn suite, plus a fix for a THD cross-attention + causal-mask bug.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Fix the causal mask calculation for cross attention with the THD layout by adding segment_pos (see the mask sketch after this list).
  • Remove segment_pad from the unit tests and encode the padding information in segment_ids (0 means padding).
  • Move the deterministic fixture so it is specific to test_praxis_layers, allowing test_fused_attn.py to run with the non-deterministic (faster and more widely used) kernels.
  • Add utility functions for the QKV layout (see the enum sketch below).
  • Separate the (attn_bias_type, bias_shape) parametrization for forward and backward to avoid a large number of skipped tests.
  • Remove the cache decorator from the fused-attention deterministic check, since it caused errors in the unit tests.
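
Below is a minimal sketch, assuming 1-D per-token segment_ids/segment_pos arrays, of the kind of mask logic the causal-mask fix implies; the helper name and signature are hypothetical and are not the PR's actual code. The key points are that causality must be judged by each token's position inside its segment (segment_pos), not by its raw index in the packed THD sequence, and that padding can be read directly from segment_ids == 0 instead of a separate segment_pad tensor.

```python
import jax.numpy as jnp

def causal_padding_mask(q_seg_ids, q_seg_pos, kv_seg_ids, kv_seg_pos):
    """Hypothetical helper: boolean [q_len, kv_len] mask, True = may attend."""
    # Tokens attend only within their own segment.
    same_segment = q_seg_ids[:, None] == kv_seg_ids[None, :]
    # Padding is encoded as segment_id == 0, so no separate segment_pad is needed.
    not_pad = (q_seg_ids[:, None] != 0) & (kv_seg_ids[None, :] != 0)
    # Causality compares positions inside the segment (segment_pos), not raw
    # token indices in the packed sequence.
    causal = q_seg_pos[:, None] >= kv_seg_pos[None, :]
    return same_segment & not_pad & causal

# Example: two segments packed back to back, followed by padding (segment_id 0).
seg_ids = jnp.array([1, 1, 1, 2, 2, 0])
seg_pos = jnp.array([0, 1, 2, 0, 1, 0])
mask = causal_padding_mask(seg_ids, seg_pos, seg_ids, seg_pos)
```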

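A similarly hedged sketch of what "utility functions for the QKV layout" could look like as a Python enum with helper predicates; the member and method names below are illustrative and may not match the ones actually added to the test suite.

```python
from enum import Enum

class QKVLayout(Enum):
    # Illustrative members only; the real enum may use different names.
    BS3HD = "bs3hd"                  # QKV packed together, padded (BSHD) format
    BSHD_BS2HD = "bshd_bs2hd"        # separate Q, packed KV
    BSHD_BSHD_BSHD = "bshd_bshd_bshd"
    T3HD = "t3hd"                    # QKV packed together, ragged (THD) format
    THD_T2HD = "thd_t2hd"
    THD_THD_THD = "thd_thd_thd"

    def is_thd(self) -> bool:
        """True for ragged (THD) layouts."""
        return not self.value.startswith("bs")

    def is_qkvpacked(self) -> bool:
        """True when Q, K, and V share one packed tensor."""
        return "3hd" in self.value

    def is_kvpacked(self) -> bool:
        """True when only K and V share a packed tensor."""
        return "2hd" in self.value
```

Helpers like these let parametrized tests decide in one place whether a combination of layout, mask, and bias shape is valid, instead of repeating string comparisons in every test.
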
Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zlsh80826 changed the title from "[JAX] FA tests refactor" to "[JAX] Fused attention unit tests fixes and refinements" on Dec 2, 2024
@zlsh80826 (Collaborator, Author) commented several times to trigger CI:

/te-ci jax L1

@mgoldfarb-nvidia (Collaborator) left a comment:

LGTM. Nice pythonic enum improvements

@@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad):
)


@pytest.mark.parametrize(
Review comment (Collaborator):

Good catch - the number of skipped tests was getting very long

zlsh80826 merged commit 7f5c784 into NVIDIA:main on Dec 17, 2024
22 checks passed
Labels: None yet
Projects: None yet

3 participants