
[C/PyTorch/Jax] Add support for more bias shapes #677

Merged (20 commits) on Feb 28, 2024

Conversation

@cyanguwa (Collaborator) commented on Feb 21, 2024

This PR

  • adds support for the [b,1,s,s], [b,h,s,s], and [1,1,s,s] bias shapes when dBias is not required. This applies to inference (bias.requires_grad = False), to cases where the bias serves as a workaround for an arbitrary mask (True/False -> 0/-inf; see the sketch after this list), and to the case where the ALiBi slopes tensor has shape [b,h].
  • only changes the F16_arbitrary_seqlen backend of the cuDNN fused attention.
  • changes the C, PyTorch and Jax attention implementations.
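
A minimal PyTorch sketch of the mask-as-bias workaround mentioned in the first bullet (illustrative only; the tensor names and sizes are assumptions, not TE code):

```python
import torch

b, h, s, d = 2, 8, 128, 64
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)

# Arbitrary boolean mask (here: causal), True = keep, False = mask out.
keep = torch.ones(s, s, dtype=torch.bool).tril()

# Fold the mask into a [1,1,s,s] post-scale bias: True/False -> 0/-inf.
bias = torch.zeros(1, 1, s, s).masked_fill(~keep, float("-inf"))
bias.requires_grad_(False)  # no gradient is ever taken w.r.t. this bias

# Reference attention: the bias broadcasts over batch and heads.
scores = torch.matmul(q, k.transpose(-2, -1)) / d**0.5 + bias
probs = torch.softmax(scores, dim=-1)
```

The [b,1,s,s] and [b,h,s,s] cases work the same way; what this PR adds is that such a no-gradient bias in these shapes is now accepted by the F16_arbitrary_seqlen backend.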

@cyanguwa changed the title from [C/PyTorch] Add b1ss/bhss/11ss bias shapes when not requiring dBias to [C/PyTorch] Add b1ss/bhss/11ss bias shapes when dBias is not required on Feb 21, 2024
@cyanguwa changed the title from [C/PyTorch] Add b1ss/bhss/11ss bias shapes when dBias is not required to [C/PyTorch] Add support for more bias shapes when dBias is not required on Feb 21, 2024
@denera (Collaborator) commented on Feb 21, 2024

@cyanguwa I merged #676 into this PR. Still need to update the unit tests to include the new bias shapes.

@cyanguwa changed the title from [C/PyTorch] Add support for more bias shapes when dBias is not required to [C/PyTorch/Jax] Add support for more bias shapes when dBias is not required on Feb 21, 2024
@cyanguwa changed the title from [C/PyTorch/Jax] Add support for more bias shapes when dBias is not required to [C/PyTorch/Jax] Add support for more bias shapes on Feb 21, 2024
Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa (Collaborator, Author) commented:

/te-ci

@denera (Collaborator) commented on Feb 22, 2024

/te-ci jax

@denera (Collaborator) commented on Feb 23, 2024

/te-ci jax

@cyanguwa (Collaborator, Author) commented on Feb 23, 2024

PyTorch pipeline with CUDA 12.3: 13014488. Fixing the A100 errors now.

Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa (Collaborator, Author) commented on Feb 24, 2024

Pipeline 13043757 for PyTorch and Paddle has passed.

@denera (Collaborator) commented on Feb 24, 2024

/te-ci jax

@denera (Collaborator) commented on Feb 27, 2024

With the latest commit, all FP16 tests in TE/JAX CI are now passing with neg_inf = -2^15.

The BF16 failures look like this:

  • Any input with max_seqlen > 512 + any mask
  • Inputs with max_seqlen <= 512 and max_seqlen_q != max_seqlen_kv + NO_MASK

It's not clear to me yet if these BF16 failures are due to a bug in the pure-JAX/Flax reference function or the TE/JAX fused attn custom op.
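
For context on the neg_inf choice (a hedged aside, not stated in the thread): -2^15 = -32768 is still finite in float16, whose largest finite magnitude is roughly 65504, so it can stand in for -inf without the masked logits overflowing when computation happens in half precision. A small sketch of the dtype effect (illustrative values, shown with PyTorch tensors; the behavior is framework-independent):

```python
import torch

neg_inf = -2.0 ** 15   # -32768: finite and exactly representable in float16 and bfloat16

# A common sentinel like -1e9 overflows to -inf once cast down to float16 ...
print(torch.tensor(-1e9, dtype=torch.float16))     # tensor(-inf, dtype=torch.float16)
# ... while -2^15 survives the cast.
print(torch.tensor(neg_inf, dtype=torch.float16))  # tensor(-32768., dtype=torch.float16)

# The masked slot still collapses to ~0 probability after softmax.
logits = torch.tensor([2.0, 1.0, neg_inf], dtype=torch.float16)
print(torch.softmax(logits.float(), dim=-1))       # ~[0.73, 0.27, 0.00]
```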

Review threads on tests/jax/test_fused_attn.py and transformer_engine/jax/cpp_extensions.py (all resolved).
@cyanguwa (Collaborator, Author) commented:

/te-ci jax

@cyanguwa requested a review from zlsh80826 on February 27, 2024, 22:03
@denera (Collaborator) commented on Feb 28, 2024

/te-ci jax

@zlsh80826 (Collaborator) left a review comment:

LGTM

Review thread on tests/jax/test_fused_attn.py (resolved).
@cyanguwa merged commit b8eea8a into NVIDIA:main on Feb 28, 2024, with 15 checks passed.