
Llama 3.2 1B Instruct on TPU v4, bumping transformers to 4.45.2 #109

Open · wants to merge 8 commits into base: main

Conversation

artus-LYTiQ

Added llama3 rope_type implementation and changed default model to Llama 3.2 1B Instruct.

Created an adaptation of the HF transformers llama3 rope_type implementation in modeling_llama.py.

Updated the dependency to the current transformers library version, 4.45.2.

Added more logging to distributed_model.py, as the TPU v4-8 VMs love to hang at random places when running this code.

Fixes #80
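For context, the llama3 rope_type adapted here rescales the RoPE inverse frequencies in three regimes: high-frequency components are kept, low-frequency components are divided by a scaling factor, and the band in between is smoothly interpolated. A minimal sketch follows, based on the transformers 4.45 rope-scaling logic; the parameter defaults below (head dim 64, base 500000, factor 32, original context 8192) are assumed Llama 3.2 config values, not taken from this PR.

```python
import math
import torch

def llama3_scaled_inv_freq(dim=64, base=500000.0, factor=32.0,
                           low_freq_factor=1.0, high_freq_factor=4.0,
                           old_context_len=8192):
    """Sketch of llama3-style RoPE frequency scaling (assumed defaults)."""
    # Standard RoPE inverse frequencies: one per pair of embedding dims.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Long wavelengths (low frequencies) are slowed by `factor`;
    # short wavelengths (high frequencies) are left untouched.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # The medium band is linearly interpolated between the two regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled), inv_freq

scaled, unscaled = llama3_scaled_inv_freq()
```

The net effect is that positions beyond the original 8192-token context reuse the low-frequency channels at a slower rate, while short-range positional resolution is preserved.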

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Collaborator

@tengomucho tengomucho left a comment


We will wait for the other contribution to be merged before merging this one, but thank you for contributing! Can you confirm which models you have tested with your changes?

tests/akg.py Outdated
next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
return next_token_id

def _test_distributed_model_generation(model_id, max_new_tokens=20):
Collaborator

For tests, please create one test similar to tests/test_distributed_model.py (or modify the existing one). To launch it, you can use pytest: `python -m pytest -sv /path/to/test_mytest.py::test_my_test_function`.
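A minimal sketch of what such a test file could look like, exercising the greedy token-selection step from the diff above; the helper and test names here are illustrative assumptions, not names from the repository.

```python
import torch

def greedy_next_token(logits):
    # Same greedy selection as the PR's sampling loop: argmax over the
    # vocab dimension, with a trailing dim added so the result is
    # shaped [batch, 1] for concatenation onto the input ids.
    return torch.argmax(logits, dim=-1)[:, None].int()

def test_greedy_next_token_shape():
    logits = torch.randn(2, 128)  # batch of 2, vocab of 128
    token = greedy_next_token(logits)
    assert token.shape == (2, 1)
    assert token.dtype == torch.int32
```

Saved as e.g. tests/test_mytest.py, it would be launched with the pytest command given above.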

Development

Successfully merging this pull request may close these issues.

Support for Llama-3.1 (8b) - inference