[JAX] Refine MHA API and add DPA API #653
Conversation
/te-ci jax
/te-ci jax
@denera, thanks for the review. Right, we currently have a lot of duplicated code for the three different qkv layouts; I think I can reorganize the code in transformer_engine/jax/csrc/modules.cpp to make it cleaner. However, I'm afraid we can't easily reduce the code in transformer_engine/jax/cpp_extensions.py, since the core difference between the FusedAttn classes and the ScaledSoftmax classes is that the number of arguments is not the same. You can see that ScaledMaskSoftmax has a different number of inputs, so it has to override the entire
@zlsh80826 I think that's fair. I agree that it's likely too much to do in this PR. I wanted to bring it up as a question just to have it on our radar. Otherwise I think the rest of it looks good. Thanks!!
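For context, here is a rough, hypothetical sketch of the argument-count mismatch discussed above. The class and method names are illustrative only and are not the actual classes in transformer_engine/jax/cpp_extensions.py; the sketch only shows why a shared base method cannot easily factor out the per-layout logic when the operand lists differ.

```python
# Hypothetical sketch: primitives whose abstract rules take different numbers of
# operands cannot share one abstract() implementation; each subclass must
# re-implement it for its own operand list.
from abc import ABC, abstractmethod


class BasePrimitive(ABC):
    @staticmethod
    @abstractmethod
    def abstract(*args):
        """Shape/dtype inference; the operand arity differs per subclass."""


class PackedQKVAttnPrimitive(BasePrimitive):
    @staticmethod
    def abstract(qkv, bias, mask, seed):  # four operands for the packed-qkv layout
        ...


class SeparateKVAttnPrimitive(BasePrimitive):
    @staticmethod
    def abstract(q, kv, mask, seed):  # a different operand list, so the whole
        ...                           # method is re-implemented rather than shared
```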
/te-ci
/te-ci
/te-ci
LGTM
/te-ci
JET CI verified.
Merging this PR. Thanks @zlsh80826!
This PR does the following changes:

1. Refine the MHA API to align with the TransformerLayer API (a hedged usage sketch of the renamed arguments follows this list):
   a. Replace `num_heads` with `num_attention_heads`.
   b. Replace `dropout_rate` with `attention_dropout`.
   c. Replace `output_layernorm` with `input_layernorm`, where this parameter is used to apply the layernorm to the inputs.
   d. Replace `apply_residual_connection_post_layernorm` with `return_layernorm_output`.
   e. Replace `fuse_qkv` with `fused_qkv_params`.
   f. The old parameters are marked as deprecated and will be removed in the future.
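Below is a minimal sketch of the renamed arguments. It assumes the module lives at `transformer_engine.jax.flax.MultiHeadAttention`, and the extra constructor arguments (`head_dim`), tensor shapes, and the call signature are illustrative assumptions, not taken from this PR.

```python
# Hedged sketch of the renamed MHA arguments; only the new argument names come
# from this PR, everything else is assumed for illustration.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MultiHeadAttention  # assumed module path

mha = MultiHeadAttention(
    num_attention_heads=16,         # was: num_heads
    head_dim=64,                    # assumed extra argument
    attention_dropout=0.1,          # was: dropout_rate
    input_layernorm=True,           # was: output_layernorm
    return_layernorm_output=False,  # was: apply_residual_connection_post_layernorm
    fused_qkv_params=True,          # was: fuse_qkv
)

x = jnp.zeros((2, 128, 1024), dtype=jnp.bfloat16)  # [batch, seqlen, hidden]; assumed shape
params = mha.init(jax.random.PRNGKey(0), x, x)     # assumed __call__ signature (q input, kv input)
out = mha.apply(params, x, x)
```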
2. Add the DotProductAttention module API.
   Since JAX doesn't have the view/stride concept, all tensors are always in contiguous memory format. To provide a unified API for `qkvpacked`, `kvpacked`, and `qkv separate` (with 1, 2, and 3 input tensors respectively), this module accepts a `qkv_layout` argument to parse the `query`, `key`, and `value` tensors for the different input combinations. See the sketch after this list.
3. Support the `BSHD_BSHD_BSHD` fused attention custom call.
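Below is a minimal sketch of the new DotProductAttention module with an explicit `qkv_layout`. The module path, constructor arguments, call signature, and tensor shapes are assumptions for illustration; only the `qkv_layout` argument and the `BSHD_BSHD_BSHD` layout name come from this PR, and the exact layout spelling accepted by the module is itself an assumption.

```python
# Hedged sketch: three separate q/k/v tensors map to the BSHD_BSHD_BSHD layout
# added by this PR; packed variants would pass fewer input tensors instead.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention  # assumed module path

b, s, h, d = 2, 128, 16, 64
q = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)
k = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)
v = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)

dpa = DotProductAttention(
    num_attention_heads=h,        # assumed argument name, mirroring the MHA rename
    head_dim=d,                   # assumed argument
    attention_dropout=0.0,        # assumed argument
    qkv_layout='BSHD_BSHD_BSHD',  # layout from this PR; exact spelling assumed
)
params = dpa.init(jax.random.PRNGKey(0), q, k, v)  # assumed __call__ signature
out = dpa.apply(params, q, k, v)
```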