
Enable users to use their own loss functions + deal with prefetching for grad accum #34198

Merged
muellerzr merged 20 commits into main from muellerzr-fix-loss-calc on Oct 17, 2024


@muellerzr (Contributor) commented Oct 16, 2024

What does this PR do?

In conjunction with #34191, this PR solves the other half of what's needed:

  1. Letting users pass their own loss functions directly to the Trainer via compute_loss
  2. Prefetching a full gradient_accumulation_steps worth of batches at the start of each optimizer step and counting how many label items they contain (num_items_in_batch), which can then be passed to a loss function that accepts it (name TBD)
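A minimal sketch of point 2 and of why the count matters, assuming each micro-batch is a list of label rows padded with -100 and using a toy per-token loss in place of cross-entropy. The helper names (prefetch_step, summed_loss, step_losses, token_loss) are hypothetical, not Trainer API, and for simplicity the count ignores the causal [..., 1:] shift. Dividing every micro-batch's summed loss by the step-wide num_items_in_batch reproduces the full-batch mean, while averaging per-micro-batch means over-weights micro-batches with few valid tokens.

```python
from itertools import islice
from typing import Iterator

IGNORE_INDEX = -100  # padded label positions that the loss skips

Batch = list[list[int]]  # one micro-batch: rows of label ids


def token_loss(label: int) -> float:
    # Toy stand-in for a per-token cross-entropy value (assumption).
    return float(label)


def prefetch_step(batches: Iterator[Batch], grad_accum_steps: int) -> tuple[list[Batch], int]:
    """Pull one optimizer step's micro-batches up front and count their label items."""
    batch_samples = list(islice(batches, grad_accum_steps))
    num_items_in_batch = sum(
        1 for mb in batch_samples for row in mb for lab in row if lab != IGNORE_INDEX
    )
    return batch_samples, num_items_in_batch


def summed_loss(micro_batch: Batch) -> tuple[float, int]:
    """Sum of per-token losses and number of non-ignored tokens in one micro-batch."""
    valid = [lab for row in micro_batch for lab in row if lab != IGNORE_INDEX]
    return sum(token_loss(lab) for lab in valid), len(valid)


def step_losses(batches: Iterator[Batch], grad_accum_steps: int) -> tuple[float, float]:
    """Return (naive, normalized) loss for one accumulated optimizer step."""
    batch_samples, num_items_in_batch = prefetch_step(batches, grad_accum_steps)
    naive = 0.0  # average of per-micro-batch mean losses
    normalized = 0.0  # every micro-batch divided by the step-wide item count
    for mb in batch_samples:
        loss_sum, n = summed_loss(mb)
        naive += (loss_sum / n) / len(batch_samples)
        normalized += loss_sum / num_items_in_batch
    return naive, normalized
```

With micro-batches [[1, 2, 3, -100]] and [[4, -100, -100, -100]], the full-batch mean over the four valid tokens is 10 / 4 = 2.5; step_losses returns 3.0 for the naive average and 2.5 for the normalized version.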

One point of feedback we need to coordinate on:

  • Should it be called num_items_in_batch and passed through to the loss functions under that name? Or is there a better name we can think of?

Fixes huggingface/trl#2175

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@LysandreJik @ArthurZucker

@muellerzr marked this pull request as ready for review October 16, 2024 17:29
@ArthurZucker (Collaborator) left a comment:

LGTM, IMO a regression test on the grad norms could be fairly nice!
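One way the suggested regression test could look, sketched on a toy one-parameter model instead of a real Trainer run (assumption): the prediction is w * x per token, the loss is squared error, and gradients are written out by hand. With a single parameter the gradient norm is just |g|. All names here are hypothetical.

```python
import math

# Toy model (assumption): each token is an (x, y) pair, prediction = w * x,
# loss per token = (w * x - y) ** 2. Padding is already removed from the lists.
MicroBatch = list[tuple[float, float]]


def grad_of_summed_loss(w: float, tokens: MicroBatch) -> float:
    """d/dw of the summed squared error over the given tokens."""
    return sum(2.0 * (w * x - y) * x for x, y in tokens)


def full_batch_grad(w: float, micro_batches: list[MicroBatch]) -> float:
    """Gradient of the mean loss over all tokens seen as one big batch."""
    tokens = [t for mb in micro_batches for t in mb]
    return grad_of_summed_loss(w, tokens) / len(tokens)


def accumulated_grad(w: float, micro_batches: list[MicroBatch], normalize_by_total: bool) -> float:
    """Gradient accumulated over micro-batches with one of two normalizations."""
    num_items_in_batch = sum(len(mb) for mb in micro_batches)
    grad = 0.0
    for mb in micro_batches:
        if normalize_by_total:
            # Divide each summed loss gradient by the token count of the whole step.
            grad += grad_of_summed_loss(w, mb) / num_items_in_batch
        else:
            # Per-micro-batch mean, then averaged over the accumulation steps.
            grad += grad_of_summed_loss(w, mb) / len(mb) / len(micro_batches)
    return grad


def test_grad_norm_matches_full_batch():
    w = 0.5
    micro_batches = [[(1.0, 2.0), (2.0, 1.0), (3.0, 0.0)], [(4.0, 4.0)]]
    expected = abs(full_batch_grad(w, micro_batches))
    assert math.isclose(abs(accumulated_grad(w, micro_batches, True)), expected)
    # Averaging per-micro-batch means drifts when token counts are uneven.
    assert not math.isclose(abs(accumulated_grad(w, micro_batches, False)), expected)
```

Here the full-batch gradient is -2.5; normalizing by num_items_in_batch reproduces it, while averaging per-micro-batch means gives -7.0.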

Comment on lines 2463 to 2472
self.state.num_input_tokens_seen += (
    torch.sum(
        self.accelerator.gather(
            torch.tensor(
                inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
            )
        )
    )
    .cpu()
    .item()
)
Collaborator:

let's make this more readable!

Contributor Author:

clean did this one 🫠

Collaborator:

You can split it into 3-4 lines 🎐
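A sketch of the same update split into a few named steps, written as a free function so it runs outside the Trainer. count_input_tokens and its gather parameter are hypothetical: gather stands in for self.accelerator.gather, and the single-process default simply turns the 0-d count into a one-element tensor.

```python
import torch


def _single_process_gather(t: torch.Tensor) -> torch.Tensor:
    # With one process there is nothing to gather; return a 1-element tensor.
    return t.reshape(1)


def count_input_tokens(input_ids: torch.Tensor, device, gather=_single_process_gather) -> int:
    """Number of input elements in this step, summed across all processes."""
    # 1. Count the elements held by this process.
    local_count = torch.tensor(input_ids.numel(), device=device, dtype=torch.int64)
    # 2. Collect the per-process counts and add them up.
    global_count = gather(local_count).sum()
    # 3. Move to host memory and convert to a Python int.
    return global_count.cpu().item()
```

Inside the Trainer this would read roughly self.state.num_input_tokens_seen += count_input_tokens(inputs[main_input_name], self.args.device, self.accelerator.gather).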

Comment on lines 3644 to 3645
if (self.label_smoother is not None or self.compute_loss is not None) and "labels" in inputs:
    labels = inputs.pop("labels")
Collaborator:

Hmm, so if people don't pass a loss function, we won't use the model's default loss?

Contributor Author:

We will: labels stays in inputs and gets passed to the model's forward().
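A minimal plain-Python sketch of the behaviour described in this exchange: labels are popped only when the Trainer will compute the loss itself (label smoothing or a user loss function); otherwise they stay in inputs and model(**inputs) computes its own default loss. The function split_labels is hypothetical, and it uses the compute_loss_fn name from the later "compute_loss -> compute_loss_fn" commit rather than the compute_loss attribute in the reviewed lines.

```python
from typing import Any, Callable, Optional


def split_labels(
    inputs: dict[str, Any],
    label_smoother: Optional[Callable] = None,
    compute_loss_fn: Optional[Callable] = None,
) -> tuple[dict[str, Any], Any]:
    """Return (inputs, labels); labels is None when the model keeps them."""
    labels = None
    if (label_smoother is not None or compute_loss_fn is not None) and "labels" in inputs:
        # The Trainer will compute the loss, so labels are taken out of the model inputs.
        labels = inputs.pop("labels")
    # Otherwise "labels" stays in inputs and model(**inputs) returns its default loss.
    return inputs, labels
```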

@muellerzr (Contributor Author) commented Oct 17, 2024

A bit more context: full fine-tuning does NOT seem to be impacted by this (when padding). I am looking into how this directly affects TRL; however, things are not as bad as they may seem.

(Below is an example CausalLM result comparing gradient accumulation 4 with batch size 8 against batch size 32, both before and after this fix.)

[Image: comparison plot]

# For now we don't support object detection
try:
    num_items_in_batch = sum(
        [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
    )
Member:

I already briefly discussed this with Zach, so this is a more general question for the other reviewers:

Would this line work for all the different task types we support? Specifically, can we always skip the first item in the sequence, i.e. is the [..., 1:] part valid?

Contributor:

For causal autoregressive models it works, but it won't work for other model types.
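A small sketch of the point above, assuming PyTorch tensors and the -100 ignore index used in the reviewed line (count_items is a hypothetical helper). Causal LMs shift labels internally, so the label at position 0 is never a prediction target and [..., 1:] counts exactly the positions that enter the loss. For sequence classification the labels usually have shape (batch,), so the same slice cuts the batch dimension and silently drops the first sample.

```python
import torch

IGNORE_INDEX = -100  # label value excluded from the loss


def count_items(labels: torch.Tensor, shift: bool) -> int:
    """Count label positions that contribute to the loss."""
    if shift:
        # Causal LM: logits at position i predict token i + 1, so label 0 is never a target.
        labels = labels[..., 1:]
    return int(labels.ne(IGNORE_INDEX).sum().item())


# Causal LM labels: shape (batch, seq_len), padded with -100.
causal_labels = torch.tensor([[5, 6, 7, IGNORE_INDEX], [8, 9, IGNORE_INDEX, IGNORE_INDEX]])
# Sequence classification labels: one class id per sample, shape (batch,).
cls_labels = torch.tensor([0, 1, 1])

print(count_items(causal_labels, shift=True))  # 3 predicted tokens
print(count_items(cls_labels, shift=True))  # 2: the slice dropped a whole sample
print(count_items(cls_labels, shift=False))  # 3
```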

@muellerzr changed the title from "[DRAFT] Enable users to use their own loss functions + deal with prefetching for grad accum" to "Enable users to use their own loss functions + deal with prefetching for grad accum" on Oct 17, 2024
@danielhanchen (Contributor) left a comment:

Just a denominator change in the test case

@ArthurZucker (Collaborator) left a comment:

Feel free to merge!

@muellerzr merged commit 6ba31a8 into main on Oct 17, 2024
25 of 26 checks passed
@muellerzr deleted the muellerzr-fix-loss-calc branch on October 17, 2024 21:01
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
…for grad accum (huggingface#34198)

* bookmark

* Bookmark

* Bookmark

* Actually implement

* Pass in kwarg explicitly

* Adjust for if we do or don't have labels

* Bookmark fix for od

* bookmark

* Fin

* closer

* Negate accelerate grad accum div

* Fixup not training long enough

* Add in compute_loss to take full model output

* Document

* compute_loss -> compute_loss_fn

* Add a test

* Refactor

* Refactor

* Uncomment tests

* Update tests/trainer/test_trainer.py

Co-authored-by: Daniel Han <[email protected]>

---------

Co-authored-by: Daniel Han <[email protected]>
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Oct 21, 2024
…for grad accum (huggingface#34198)

Successfully merging this pull request may close these issues.

Gradient accumulation yields worse results than the equivalent batch size
6 participants