
[JAX] Support various implementations of RoPE. #655

Merged: 1 commit merged into NVIDIA:main from mingh/support_variant_rope_impl on Feb 27, 2024

Conversation

mingxu1067 (Collaborator)

  1. Support various pairing approaches of coordinates in RoPE: 'alternate' pairs index i with i + d/2, where d is the hidden dimension; 'consecutive' pairs index i with i + 1. A sketch of both conventions follows.
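
For illustration, here is a minimal JAX sketch of the two pairing conventions. This is not the Transformer Engine implementation; the function and argument names (`apply_rope`, `group_method`) are assumptions made for the example.

```python
import jax.numpy as jnp

def apply_rope(x, positions, group_method="alternate", base=10000.0):
    """x: [seq, heads, dim]; positions: [seq] integer positions."""
    dim = x.shape[-1]
    half = dim // 2
    # One rotation frequency per coordinate pair.
    freqs = 1.0 / (base ** (jnp.arange(half) * 2.0 / dim))   # [half]
    angles = positions[:, None] * freqs[None, :]              # [seq, half]
    cos = jnp.cos(angles)[:, None, :]                         # [seq, 1, half]
    sin = jnp.sin(angles)[:, None, :]

    if group_method == "alternate":
        # Pair index i with i + dim/2 (rotate-half style).
        x1, x2 = x[..., :half], x[..., half:]
        return jnp.concatenate(
            [x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    if group_method == "consecutive":
        # Pair adjacent indices i and i + 1 (interleaved style).
        x1, x2 = x[..., 0::2], x[..., 1::2]
        rotated = jnp.stack(
            [x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
        return rotated.reshape(x.shape)
    raise ValueError(f"Unsupported RoPE group_method: {group_method}")
```

Both branches apply the same rotation; they differ only in which coordinates are grouped into a pair, which is why mixing conventions between training and inference code produces numerically different results.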

@mingxu1067 mingxu1067 requested review from nouiz and denera February 5, 2024 08:35
@mingxu1067 (Collaborator, Author)

/te-ci jax

@denera (Collaborator) left a comment


LGTM! Just a minor comment on an (optional) error message for bad RoPE method.

(Review comment on transformer_engine/jax/flax/transformer.py: outdated, resolved)
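
For reference, the kind of guard the reviewer is suggesting might look like the following. This is a hypothetical sketch; the constant and function names are not from the actual transformer_engine/jax code.

```python
# Hypothetical illustration of the suggested error message for a bad RoPE
# method; names here are assumptions, not the actual Transformer Engine API.
SUPPORTED_ROPE_GROUP_METHODS = ("alternate", "consecutive")

def _check_rope_group_method(group_method: str) -> None:
    if group_method not in SUPPORTED_ROPE_GROUP_METHODS:
        raise ValueError(
            f"group_method must be one of {SUPPORTED_ROPE_GROUP_METHODS}, "
            f"got {group_method!r}.")
```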
@nouiz (Collaborator) left a comment


LGTM

@denera (Collaborator) commented Feb 5, 2024

@nouiz Do we need JET CI for this?

@nouiz (Collaborator) commented Feb 5, 2024

I think so. If I understood correctly, we discovered this issue because the current RoPE wasn't exercised by JET CI, so this would be the first CI run that uses it.
Do we need to update the script to make use of it, or will that happen automatically?

@mingxu1067 mingxu1067 changed the title Support various implementations of RoPE. [JAX] Support various implementations of RoPE. Feb 6, 2024
@mingxu1067 (Collaborator, Author)

/te-ci jax

@nouiz nouiz added the jax label Feb 6, 2024
@zlsh80826 (Collaborator)

Hi @mingxu1067, I found that this PR causes a functional regression on the LLaMA model. We can discuss offline next week.

@mingxu1067 mingxu1067 force-pushed the mingh/support_variant_rope_impl branch from f61518e to 1c18a40 Compare February 26, 2024 03:50
@mingxu1067 (Collaborator, Author)

/te-ci jax

@mingxu1067 mingxu1067 force-pushed the mingh/support_variant_rope_impl branch from 1c18a40 to a9bb39e Compare February 26, 2024 06:55
@mingxu1067 mingxu1067 force-pushed the mingh/support_variant_rope_impl branch from a9bb39e to 375209b Compare February 26, 2024 07:31
@mingxu1067 (Collaborator, Author)

/te-ci jax

@mingxu1067 (Collaborator, Author)

@denera @nouiz @zlsh80826
Anna verified that this PR achieves the expected accuracy on llama2-7B.
Please help review and merge this PR. Thank you.

@denera (Collaborator) commented Feb 27, 2024

@mingxu1067 LGTM! Thanks for seeing this through.

@denera denera merged commit 8bba5ee into NVIDIA:main Feb 27, 2024
15 checks passed