[C/PyTorch/Jax] Add support for more bias shapes #677
Conversation
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
…ze' into fused_attn/add_dbias_shapes_c_pytorch Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
/te-ci
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
PyTorch pipeline with CUDA 12.3: 13014488. Fixing A100 errors now.
Signed-off-by: Charlene Yang <[email protected]>
Pipeline 13043757 for PyTorch and Paddle has passed.
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
…h JAX Signed-off-by: Alp Dener <[email protected]>
…and h_q == h_kv conditions Signed-off-by: Alp Dener <[email protected]>
With the latest commit, all FP16 tests in TE/JAX CI are now passing. The BF16 failures look like this:
It's not clear to me yet whether these BF16 failures are due to a bug in the pure-JAX/Flax reference function or in the TE/JAX fused attn custom op.
…27 for Bfloat16 and -2**15 for Float16 Signed-off-by: Alp Dener <[email protected]>
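For context on the commit above, here is a hedged sketch of the idea it describes: replacing -inf with dtype-dependent finite fill values when materializing a boolean mask as an attention bias. The helper name is hypothetical, the bfloat16 constant is an assumption inferred from the truncated commit message (only -2**15 for Float16 is stated explicitly), and the actual change lands in the JAX code path rather than in PyTorch.

```python
import torch

def neg_inf_substitute(dtype: torch.dtype) -> float:
    # Finite stand-ins for -inf so that masked logits stay well-behaved in
    # reduced precision; -2**15 for float16 appears in the commit message,
    # -2**27 for bfloat16 is an assumption based on its truncated text.
    if dtype == torch.bfloat16:
        return -2.0**27
    if dtype == torch.float16:
        return -2.0**15
    return float("-inf")

keep = torch.rand(2, 1, 8, 8) > 0.1                  # boolean mask, True = attend
bias = torch.zeros(2, 1, 8, 8, dtype=torch.bfloat16)
bias = bias.masked_fill(~keep, neg_inf_substitute(bias.dtype))
```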
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
LGTM
This PR adds support for [b,1,s,s], [b,h,s,s], and [1,1,s,s] bias shapes when dBias is not required. This is applicable to inference (bias.requires_grad = False), when bias is used as a workaround for an arbitrary mask (True/False -> 0/-inf), or when the ALiBi slopes tensor is in [b,h] shape. The new shapes are supported in the F16_arbitrary_seqlen backend of the cuDNN fused attention only.
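For readers unfamiliar with how these broadcastable bias shapes are consumed, here is a minimal pure-PyTorch reference sketch (not the fused cuDNN kernel; the function and tensor names are placeholders for illustration):

```python
import torch

def attention_with_bias(q, k, v, bias):
    """Reference attention with a broadcastable post-scale bias.

    q, k, v: [b, h, s, d]; bias: [b, 1, s, s], [1, 1, s, s] or [b, h, s, s].
    The bias simply broadcasts against the [b, h, s, s] logits, so no dBias
    gradient is needed when bias.requires_grad is False (e.g. inference).
    """
    d = q.shape[-1]
    logits = torch.matmul(q, k.transpose(-2, -1)) / d**0.5   # [b, h, s, s]
    logits = logits + bias                                   # broadcasts over b and/or h
    probs = torch.softmax(logits, dim=-1)
    return torch.matmul(probs, v)

# Bias built from an arbitrary boolean mask (True/False -> 0/-inf),
# one of the use cases named in the PR description.
b, h, s, d = 2, 4, 8, 16
q = torch.randn(b, h, s, d)
k, v = torch.randn_like(q), torch.randn_like(q)
keep = torch.rand(b, 1, s, s) > 0.1                          # True = attend
bias = torch.zeros(b, 1, s, s).masked_fill(~keep, float("-inf"))
out = attention_with_bias(q, k, v, bias)                     # bias.requires_grad is False
```

Because the bias only broadcasts into the [b,h,s,s] logits and carries requires_grad = False, the backward pass never has to reduce a dBias gradient over the broadcast dimensions, which is what makes these inference-only shapes viable.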