
Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158

Open · Abhishek-TAMU wants to merge 9 commits into main

Conversation


@Abhishek-TAMU Abhishek-TAMU commented Oct 3, 2024

What does this PR do?

This PR adds `cu_seq_lens_q`, `cu_seq_lens_k`, `max_length_k`, and `max_length_q` to the batch produced by `DataCollatorForCompletionOnlyLM`. Together with a companion PR in transformers (link to be added), this removes the graph breaks that `torch.compile()` hits during padding-free tuning, allowing maximum performance to be obtained.
Specifically, these parameters should be generated in the collator (this PR's change), outside the transformers forward loop, because computing them incurs a cpu-gpu sync that is unavoidable. Otherwise that sync happens inside the attention call, which causes graph breaks; the companion transformers PR removes that call so that no graph breaks remain when the `torch_compile` flag is enabled in the training arguments used with `SFTTrainer`.
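A minimal sketch of the collator-side computation, assuming the padding-free layout in which the batch is a single packed row whose `position_ids` reset to 0 at every sequence boundary; `_flash_attn_kwargs` is a hypothetical helper name for illustration, not the PR's actual code:

```python
import torch


def _flash_attn_kwargs(position_ids: torch.Tensor) -> dict:
    """Derive FlashAttention varlen metadata from packed position_ids.

    Expects a padding-free batch of shape (1, total_tokens) whose
    position_ids restart at 0 at each sequence boundary,
    e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] for three packed sequences.
    """
    pos = position_ids.flatten()
    # Sequence starts are wherever the position counter resets to 0;
    # appending total_tokens turns them into cumulative lengths:
    # [0, len_0, len_0 + len_1, ..., total_tokens].
    starts = (pos == 0).nonzero(as_tuple=True)[0]
    cu_seq_lens = torch.cat(
        [starts, torch.tensor([pos.numel()], dtype=starts.dtype)]
    ).to(torch.int32)
    # This .max() is the sync that moves out of the model's forward:
    # in the collator it runs on CPU tensors, so no cpu-gpu sync is
    # left inside the compiled graph.
    max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
    # Queries and keys share the same layout in self-attention.
    return {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max_length,
        "max_length_k": max_length,
    }
```

The collator would then fold these into the returned batch, e.g. `batch.update(_flash_attn_kwargs(batch["position_ids"]))`, so they reach the model as keyword arguments.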

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Abhishek-TAMU Abhishek-TAMU changed the title Add Sequence Lengths to Batch in DataCollatorForCompletionOnlyLM Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM Oct 3, 2024
@Abhishek-TAMU Abhishek-TAMU marked this pull request as ready for review October 3, 2024 15:42
@Abhishek-TAMU (Author)

CC: @kashif @qgallouedec

@kashif kashif added the ✨ enhancement and 🏋 SFT labels Oct 6, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member) commented Oct 10, 2024

Hi, thanks for the PR.
Can you provide a link to the transformers PR? Is it huggingface/transformers#33932?

@qgallouedec (Member)

Could you provide a simple test that:

  1. Confirms the current failure (the graph breaks that occur without this change).
  2. Verifies that this addition resolves it.

It might also be helpful to add a few comments, as these lines are unclear without context.
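A minimal sketch of the collator-level half of such a test, assuming `DataCollatorForCompletionOnlyLM` accepts a `padding_free` flag (as the padding-free branch suggests); the model name and templates below are placeholders:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM


def test_padding_free_collator_emits_flash_attn_kwargs():
    # "gpt2" and the templates are placeholders for illustration.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token
    collator = DataCollatorForCompletionOnlyLM(
        response_template="### Answer:",
        tokenizer=tokenizer,
        padding_free=True,
    )
    examples = [
        tokenizer("### Question: 2+2\n### Answer: 4"),
        tokenizer("### Question: capital of France\n### Answer: Paris"),
    ]
    batch = collator(examples)
    # The new keys must be present in the collated batch...
    for key in ("cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"):
        assert key in batch
    # ...and cu_seq_lens must span the packed row: starting at 0 and
    # ending at the total number of tokens.
    assert batch["cu_seq_lens_q"][0] == 0
    assert batch["cu_seq_lens_q"][-1] == batch["input_ids"].numel()
```

Checking that the graph breaks are actually gone would additionally need the companion transformers change, e.g. by compiling a small model on such a batch and inspecting the graph-break count reported by `torch._dynamo.explain`.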

@qgallouedec qgallouedec added the 🐛 bug label and removed the ✨ enhancement label Oct 10, 2024