Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first version of flash_attention for jax #19743

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

vulkomilev
Copy link

This is my first version of the flash attention implementation .It is just for Jax.

@fchollet
Copy link
Member

Thanks for the PR! Have you tried to time it on GPU compared to regular attention? I was under the impression that we were going to need a custom Pallas kernel for this.

@vulkomilev
Copy link
Author

I have used /keras/src/layers/attention/ directory as a template for implementing a flash attention but I don't understand how the mask is generated in the benchmark. I need one but I don't see it

@gbaned
Copy link
Collaborator

gbaned commented Jul 12, 2024

Hi @fchollet Can you please review this PR? Thank you!

@@ -76,6 +82,44 @@ def relu6(x):
return Relu6().symbolic_call(x)
return backend.nn.relu6(x)

@keras_export(["keras.ops.flash_attention", "keras.ops.nn.flash_attention"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello there,

That's a very wonderful addition. I've a bit of doubt there:

I suppose ops.flash_attention is an operation, while nn.flash_attention is just a neural network layer.
The basic difference between these two is that - an operation may not have any trainable parameter with it, while a neural network layer should have trainable parameters.

Am I right till now?

If yes, please provide separate examples of each one of them in the docs!

Best Regards,
Abhas Kumar Sinha

Returns:
A tensor with the same shape as `x`.

Example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enclose the Example with "```".

"""python
    This is an example documentation.
    Example:
    ```
        example = example()
    ```
"""

This helps automated doc renderers to automatically find out code examples from the program docs and render those parts accordingly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

4 participants