[JAX] Fused attention unit tests fixes and refinements #1352
Signed-off-by: Reese Wang <[email protected]>
LGTM. Nice pythonic enum improvements
@@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad):
    )

@pytest.mark.parametrize(
Good catch - the number of skipped tests was getting very long
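To illustrate why joint parametrization cuts down on skipped tests: when two parameters are parametrized independently, pytest generates their Cartesian product, and combinations that make no sense have to be skipped at runtime. Listing only the valid (attn_bias_type, bias_shape) pairs under one parametrize call avoids generating those cases at all. The bias-type and shape strings and the is_meaningful rule (a shape only matters when a bias tensor is present) are illustrative assumptions, not the test module's actual enums or skip conditions.

```python
import itertools

# Illustrative values only (assumed); the real test module uses its own enums.
BIAS_TYPES = ["NO_BIAS", "POST_SCALE_BIAS"]
BIAS_SHAPES = ["1HSS", "B1SS", "BHSS", "11SS"]


def is_meaningful(bias_type, bias_shape):
    # Assumed rule: with no bias tensor the shape is irrelevant,
    # so only one representative shape is kept for NO_BIAS.
    return bias_type != "NO_BIAS" or bias_shape == BIAS_SHAPES[0]


# Independent parametrization: every combination becomes a test case,
# and the meaningless ones must call pytest.skip() at runtime.
product_cases = list(itertools.product(BIAS_TYPES, BIAS_SHAPES))

# Joint parametrization: enumerate only the valid pairs up front, e.g.
#   @pytest.mark.parametrize("attn_bias_type, bias_shape", JOINT_CASES)
JOINT_CASES = [pair for pair in product_cases if is_meaningful(*pair)]

skipped = len(product_cases) - len(JOINT_CASES)
print(len(product_cases), len(JOINT_CASES), skipped)  # 8 5 3
```

With both the forward and the backward tests parametrized this way, the saving applies twice, and it multiplies with every other parametrized axis of the test.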
Description
Refinements to the JAX test_fused_attn unit tests, plus a fix for a THD cross-attention + causal bug.
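In the THD (packed) layout, several variable-length sequences share one token axis and are told apart by segment IDs. In these tests a segment ID of 0 marks padding. The sketch below shows one way such IDs, together with per-token segment positions, can be turned into a causal cross-attention mask. The helper name make_thd_mask and the top-left causal alignment (query position i attends to key positions <= i within the same segment) are assumptions for illustration; this is not the test suite's helper, not Transformer Engine's definition of its causal mask variants, and not a reproduction of the fix in this PR. It uses NumPy; the same code should run with jax.numpy substituted.

```python
import numpy as np


def make_thd_mask(q_seg_ids, kv_seg_ids, q_seg_pos, kv_seg_pos, causal=True):
    """Hypothetical helper: boolean [batch, q_len, kv_len] mask, True = attend.

    Segment ID 0 marks padding; positions count from 0 within each segment.
    """
    q_ids = q_seg_ids[:, :, None]    # [b, q, 1]
    kv_ids = kv_seg_ids[:, None, :]  # [b, 1, kv]
    # Tokens may only attend within the same packed sequence...
    mask = q_ids == kv_ids
    # ...and never to or from padding (segment ID 0).
    mask &= (q_ids != 0) & (kv_ids != 0)
    if causal:
        # Assumed top-left alignment: query i sees keys at positions <= i.
        mask &= q_seg_pos[:, :, None] >= kv_seg_pos[:, None, :]
    return mask


# One batch row: two query sequences (lengths 2 and 1) plus one pad token,
# attending to two key/value sequences (lengths 3 and 2) plus one pad token.
q_ids = np.array([[1, 1, 2, 0]])
q_pos = np.array([[0, 1, 0, 0]])
kv_ids = np.array([[1, 1, 1, 2, 2, 0]])
kv_pos = np.array([[0, 1, 2, 0, 1, 0]])

mask = make_thd_mask(q_ids, kv_ids, q_pos, kv_pos)
print(mask[0].astype(int))
```

Because query and key/value lengths differ in cross attention, the causal comparison here is done on within-segment positions rather than on raw token indices; how the two sequences should be aligned (top-left versus bottom-right) is a design decision this sketch simply fixes by assumption.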
Type of change
Changes
- segment_pos, segment_pad in the unit tests, and encoded the padding information into segment_ids (0s mean padding).
- test_praxis_layers specific, which allows test_fused_attn.py to run with the non-deterministic (faster and widely used) kernels.
- "attn_bias_type, bias_shape" for forward and backward to avoid lots of skipped tests.
- cache decorator for fused attn deterministic, which causes errors in the unit tests.

Checklist: