-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Refine MHA API and add DPA API (#653)
* Refine MHA API Signed-off-by: Reese Wang <[email protected]> * Reuse func from the flax Signed-off-by: Reese Wang <[email protected]> * DPA draft Signed-off-by: Reese Wang <[email protected]> * qkv packed draft Signed-off-by: Reese Wang <[email protected]> * Fix test_layer with fused attn Signed-off-by: Reese Wang <[email protected]> * Add attn_bias_type and enhance a few code flow Signed-off-by: Reese Wang <[email protected]> * Move scale_factor from __call__ to init Signed-off-by: Reese Wang <[email protected]> * Enhance the docs Signed-off-by: Reese Wang <[email protected]> * Add DPA public API and tests Signed-off-by: Reese Wang <[email protected]> * Refine docs Signed-off-by: Reese Wang <[email protected]> * Refine docs Signed-off-by: Reese Wang <[email protected]> * Fix conflict Signed-off-by: Reese Wang <[email protected]> * Add qkv separate fused attn Signed-off-by: Reese Wang <[email protected]> * Apply BSHD_BSHD_BSHD format Signed-off-by: Reese Wang <[email protected]> * Remove debug log Signed-off-by: Reese Wang <[email protected]> * Add fused attention layer tests Signed-off-by: Reese Wang <[email protected]> * Add NVTE_FUSED_ATTN docs Signed-off-by: Reese Wang <[email protected]> * Fine-grained fused attn settings Signed-off-by: Reese Wang <[email protected]> * Remove the default value of num_attetnion_head and head_dim Signed-off-by: Reese Wang <[email protected]> * Add teardown for fused attn env Signed-off-by: Reese Wang <[email protected]> * Unify the Optional notation Signed-off-by: Reese Wang <[email protected]> * Fix Pre/Post scale bias comments Signed-off-by: Reese Wang <[email protected]> * Add no_mask tests Signed-off-by: Reese Wang <[email protected]> * Add checkpoint_name for fused attn Signed-off-by: Reese Wang <[email protected]> * Fix the fused attn batcher Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Reese Wang <[email protected]>
- Loading branch information
Showing
15 changed files
with
1,820 additions
and
477 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.