
online DPO evaluation #2228

Open · 1 of 4 tasks
woshizouguo opened this issue Oct 14, 2024 · 1 comment · May be fixed by #2231
Labels
🐛 bug (Something isn't working) · 🏋 Online DPO (Related to Online DPO)

Comments


woshizouguo commented Oct 14, 2024

System Info

trl=0.11.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

This is for the online DPO example code (examples/scripts/dpo_online.py).

If I add --eval_steps=5 and --eval_strategy=steps, training fails during evaluation. A minimal sketch of the triggering configuration is shown below, followed by the traceback:
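
For reference, here is a minimal, hypothetical sketch of the triggering configuration. OnlineDPOConfig subclasses transformers.TrainingArguments, so eval_strategy and eval_steps are inherited fields; output_dir is a placeholder, and everything else (model, reward model, dataset, trainer) is set up as in examples/scripts/dpo_online.py.

```python
from trl import OnlineDPOConfig

# Hypothetical minimal config; output_dir is a placeholder path.
training_args = OnlineDPOConfig(
    output_dir="online-dpo-debug",
    eval_strategy="steps",  # run evaluation during training
    eval_steps=5,           # evaluate every 5 optimizer steps
)
```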

  File "/mnt/task_runtime/trl/examples/scripts/dpo_online.py", line 120, in <module>
    trainer.train()
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/mnt/task_runtime/trl/trl/trainer/online_dpo_trainer.py", line 555, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/opt/miniconda/lib/python3.9/site-packages/transformers/trainer.py", line 4085, in prediction_step
    outputs = model(**inputs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 186, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 201, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 109, in parallel_apply
    output.reraise()
  File "/opt/miniconda/lib/python3.9/site-packages/torch/_utils.py", line 706, in reraise
    raise exception
TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 84, in _worker
    output = module(*input, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'prompt'
```

Expected behavior

Evaluation should run without error. The problem is that the eval loop does not get model-ready inputs: the batch produced by the online DPO data collator (which still contains the raw prompt field) is passed straight to model(**inputs), and the model's forward() rejects the prompt keyword.

kashif (Collaborator) commented Oct 15, 2024

@wenxindongwork I suspect we will need our own prediction_step method, since we use our own data collator instead of the default one. The tests didn't catch this bug because eval_steps in the tests was greater than max_steps, so the evaluation never ran.
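
For illustration, a minimal sketch of what such an override might look like. This is not the fix in the linked PR #2231; the subclass name is hypothetical, it assumes the eval batch carries raw prompt strings under the "prompt" key (as the traceback suggests), and it returns only a placeholder loss. A real implementation would generate completions and score them with the reward model or judge.

```python
import torch

from trl import OnlineDPOTrainer


class PatchedOnlineDPOTrainer(OnlineDPOTrainer):
    """Hypothetical subclass sketching a custom prediction_step."""

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # The online DPO data collator hands the trainer raw prompts rather than
        # tokenized tensors, so the default Trainer.prediction_step, which calls
        # model(**inputs), fails with "forward() got an unexpected keyword
        # argument 'prompt'". Prepare model-ready inputs here instead.
        prompts = inputs["prompt"]  # assumption: a list of prompt strings
        batch = self.tokenizer(
            prompts, padding=True, truncation=True, return_tensors="pt"
        ).to(self.args.device)

        with torch.no_grad():
            # Placeholder metric: the policy's LM loss on the prompts. A real
            # eval would generate completions, score them with the reward
            # model / judge, and log win rates or reward margins instead.
            outputs = model(**batch, labels=batch["input_ids"])

        return (outputs.loss.mean().detach(), None, None)
```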

@kashif kashif added the 🐛 bug Something isn't working label Oct 15, 2024
@kashif kashif linked a pull request Oct 15, 2024 that will close this issue
@qgallouedec qgallouedec added the 🏋 Online DPO Related to Online DPO label Oct 15, 2024