[JAX] Refine MHA API and add DPA API #653

Merged · 25 commits · Feb 22, 2024

Conversation

@zlsh80826 (Collaborator) commented Feb 5, 2024

This PR makes the following changes:

  1. Refine the MHA API to align with the TransformerLayer API (see the first sketch after this list):
    a. Replace num_heads with num_attention_heads.
    b. Replace dropout_rate with attention_dropout.
    c. Replace output_layernorm with input_layernorm, where this param applies the layernorm to the inputs.
    d. Replace apply_residual_connection_post_layernorm with return_layernorm_output.
    e. Replace fuse_qkv with fused_qkv_params.
    f. The old params are marked as deprecated and will be removed in the future.

  2. Add DotProductAttention module API
    Since JAX doesn't have the view/stride concept, all tensors are always in contiguous memory format. To provide a unified API for the qkvpacked, kvpacked, and separate-qkv cases (with 1, 2, and 3 input tensors respectively), this module accepts a qkv_layout argument that determines how the query, key, and value tensors are parsed from the inputs (see the second sketch after this list):

     * bs3hd: the query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d].
       The key and value arguments in :attr:`__call__()` are ignored in this layout.
     * bshd_bs2hd: the query tensor has shape = [b, s, h, d]. The key tensor is treated as a kvpacked
       tensor with shape = [b, s, 2, h, d]. The value argument in :attr:`__call__()` is ignored.
     * bshd_bshd_bshd: query, key, and value are separate tensors, each with shape = [b, s, h, d].
    
  3. Support the BSHD_BSHD_BSHD fused attention custom call
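
First sketch, for item 1: constructing the module with the renamed arguments. The import path and any argument not named in the PR description (e.g. head_dim) are assumptions for illustration, not confirmed by this PR.

```python
# Hedged sketch: the renamed MultiHeadAttention arguments from item 1.
# Import path and head_dim are assumptions; the renames are from the PR text.
from transformer_engine.jax.flax import MultiHeadAttention

mha = MultiHeadAttention(
    head_dim=64,                    # assumed, for illustration
    num_attention_heads=16,         # was: num_heads
    attention_dropout=0.1,          # was: dropout_rate
    input_layernorm=True,           # was: output_layernorm; applies LN to the inputs
    return_layernorm_output=False,  # was: apply_residual_connection_post_layernorm
    fused_qkv_params=True,          # was: fuse_qkv
)
# The deprecated spellings still work for now but will be removed in the
# future (per item 1f).
```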
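Second sketch, for item 2: a minimal, self-contained illustration of how the three qkv_layout values could recover separate query/key/value tensors. parse_qkv is a hypothetical helper written for this example, not TE's actual implementation.

```python
# Minimal sketch (hypothetical helper, not TE's implementation) of the
# qkv_layout parsing described in item 2.
import jax.numpy as jnp

def parse_qkv(qkv_layout, query, key=None, value=None):
    if qkv_layout == 'bs3hd':
        # query is qkvpacked [b, s, 3, h, d]; key/value are ignored.
        q, k, v = jnp.split(query, 3, axis=2)
        return q.squeeze(2), k.squeeze(2), v.squeeze(2)
    if qkv_layout == 'bshd_bs2hd':
        # key is kvpacked [b, s, 2, h, d]; value is ignored.
        k, v = jnp.split(key, 2, axis=2)
        return query, k.squeeze(2), v.squeeze(2)
    if qkv_layout == 'bshd_bshd_bshd':
        # query, key, and value are all separate [b, s, h, d] tensors.
        return query, key, value
    raise ValueError(f'unknown qkv_layout: {qkv_layout}')

b, s, h, d = 2, 128, 16, 64
packed = jnp.zeros((b, s, 3, h, d))
q, k, v = parse_qkv('bs3hd', packed)  # each [b, s, h, d]
```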

Signed-off-by: Reese Wang <[email protected]>
@zlsh80826 (Collaborator Author)

/te-ci jax

@zlsh80826 mentioned this pull request Feb 5, 2024
Signed-off-by: Reese Wang <[email protected]>
@zlsh80826 (Collaborator Author)

/te-ci jax

Review thread: qa/L0_jax_unittest/test.sh (outdated, resolved)
@zlsh80826 (Collaborator Author) commented Feb 6, 2024

@denera, thanks for the review. Right, we currently have a lot of duplicated code for the three different qkv layouts; I think I can reorganize the code in transformer_engine/jax/csrc/modules.cpp to make it cleaner. However, I'm afraid we can't easily reduce the code in transformer_engine/jax/cpp_extensions.py, since the core difference between the FusedAttn classes and the scaledSoftmax classes is that their numbers of arguments are not the same. You can see that ScaledMaskSoftmax has a different number of inputs, so it has to override the entire abstract function. The three different FusedAttn layouts have different numbers of inputs, so the base class can't help (see the sketch below). I have some ideas to integrate them, but I'm afraid that would make this PR too long and delay the DPA release. I would like to open another PR to refactor cpp_extensions.py and fused_attn.py (and also rename self_attn to fused_attn_qkvpacked and cross_attn to fused_attn_kvpacked). What do you think?
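
An illustrative sketch (not TE's code, class and argument names are made up for this example) of why the differing input arity blocks a shared base: each layout's abstract-eval rule has a different positional signature, so each primitive must override it in full.

```python
# Illustrative sketch: two fused-attention primitives whose abstract rules
# differ in arity, so neither can reuse the other's signature via a base class.
class FusedAttnQKVPackedPrimitive:
    @staticmethod
    def abstract(qkv, bias, seqlen, seed, **params):
        """Shape/dtype rule for the single packed-input layout."""
        ...

class FusedAttnSeparatePrimitive:
    @staticmethod
    def abstract(q, k, v, bias, q_seqlen, kv_seqlen, seed, **params):
        """Same logic, but a different number of inputs: full override needed."""
        ...
```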

@denera (Collaborator) commented Feb 6, 2024

@zlsh80826 I think that's fair. I agree that it's likely too much to do in this PR. I wanted to bring it up as a question just to have it on our radar. Otherwise I think the rest of it looks good. Thanks!!

@zlsh80826 (Collaborator Author)

/te-ci

@zlsh80826 (Collaborator Author)

/te-ci

Review thread: tests/jax/test_praxis_layers.py (resolved)
Review thread: tests/jax/test_praxis_layers.py (resolved)
Review thread: transformer_engine/jax/cpp_extensions.py (outdated, resolved)
@zlsh80826 (Collaborator Author)

/te-ci

@nouiz added the jax label Feb 14, 2024
@cyanguwa (Collaborator) left a comment

LGTM

@zlsh80826 (Collaborator Author)

/te-ci

@zlsh80826 (Collaborator Author)

JET CI verified.

@denera (Collaborator) commented Feb 22, 2024

Merging this PR. Thanks @zlsh80826!

@denera denera merged commit 9b2fed5 into NVIDIA:main Feb 22, 2024
25 of 28 checks passed