
[RewardTrainer] Tokenize inputs within trainer #2102

Merged · 15 commits · Sep 24, 2024
42 changes: 2 additions & 40 deletions README.md
@@ -103,15 +103,12 @@ from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

# configure trainer
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)

# train
trainer.train()
```

@@ -120,7 +117,7 @@ trainer.train()
Here is a basic example of how to use the `RewardTrainer`:

```python
from trl import RewardConfig, RewardTrainer
from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
from trl.extras.dataset_formatting import conversations_formatting_function
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -132,33 +129,7 @@ model = AutoModelForSequenceClassification.from_pretrained(
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")

def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(
tokenized_rejected["attention_mask"]
)

return new_examples

chosen_fn = conversations_formatting_function(tokenizer, "chosen")
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)})
dataset = dataset.map(
preprocess_function,
batched=True,
)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
Member:
What about adding it in the trainer as well?

Member Author:
Yes, I think this would be nice, to keep it consistent with the SFTTrainer! I'll push a change and fix the tests.
We should later apply this to the other trainers.

Member:
Yes! See #2071.

Member Author:
Done in 13b5ed0.

training_args = RewardConfig(
per_device_train_batch_size=2,
@@ -171,8 +142,6 @@ trainer = RLOOTrainer(
tokenizer=tokenizer,
train_dataset=dataset,
)

# train
trainer.train()
```
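
The updated example replaces the manual `conversations_formatting_function` plus tokenization pipeline with a single `maybe_apply_chat_template` map; tokenization now happens inside `RewardTrainer` (see the trainer diff below). As a minimal sketch of what that map does to one row, assuming an Instruct tokenizer and a made-up preference row in the implicit-prompt conversational format (exact output strings depend on the model's chat template):

```python
from transformers import AutoTokenizer
from trl import maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # assumed model

# One preference row in the OpenAI messages format, with an implicit prompt.
row = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}

# Renders each message list into a single string via the tokenizer's chat
# template; rows that are already plain strings pass through unchanged.
row = maybe_apply_chat_template(row, tokenizer=tokenizer)
print(row["chosen"])    # chosen conversation as templated text
print(row["rejected"])  # rejected conversation as templated text
```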

@@ -210,7 +179,6 @@ trainer = RLOOTrainer(
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
# train
trainer.train()
```

@@ -219,30 +187,24 @@ trainer.train()
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
# imports
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# load preference dataset - needs to be in a specific format
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

# load trainer
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)

# train
trainer.train()
```
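
The two `map` calls above prepare the preference data before it reaches `DPOTrainer`. A rough sketch of what happens to a single made-up row, with the same assumed tokenizer (`maybe_extract_prompt` should leave rows untouched if a `prompt` column already exists; exact strings depend on the chat template):

```python
from transformers import AutoTokenizer
from trl import maybe_extract_prompt, maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# A preference row where chosen and rejected still repeat the shared user turn.
row = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

# Step 1: split the shared prefix out into an explicit "prompt" field.
row = maybe_extract_prompt(row)

# Step 2: render the message lists into plain strings with the chat template.
row = maybe_apply_chat_template(row, tokenizer=tokenizer)

print(row["prompt"])    # templated shared prompt
print(row["chosen"])    # templated chosen completion
print(row["rejected"])  # templated rejected completion
```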

37 changes: 2 additions & 35 deletions examples/scripts/reward_modeling.py
@@ -63,10 +63,10 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
maybe_apply_chat_template,
setup_chat_format,
)
from trl.commands.cli_utils import RewardScriptArguments
from trl.extras.dataset_formatting import conversations_formatting_function


tqdm.pandas()
@@ -115,42 +115,9 @@
#############################
dataset = load_dataset(args.dataset_name)

def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

return new_examples

with PartialState().local_main_process_first():
# Wrap inputs with chat template.
# This assumes the chosen/rejected columns are in the OpenAI messages format.
chosen_fn = conversations_formatting_function(tokenizer, "chosen")
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
dataset = dataset.map(
lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
)
# Tokenize inputs
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=training_args.dataset_num_proc,
)
# Filter out examples that are too long
dataset = dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
and len(x["input_ids_rejected"]) <= training_args.max_length,
num_proc=training_args.dataset_num_proc,
maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
)

##########
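For context, `PartialState().local_main_process_first()` (used above, and reused inside the trainer below) is the usual Accelerate pattern for dataset preprocessing in multi-process runs: the local main process executes the block first and writes the datasets cache, and the other local ranks then reuse the cached result instead of recomputing it. A minimal sketch of the pattern, with a toy dataset and map function invented for illustration:

```python
from accelerate import PartialState
from datasets import Dataset

dataset = Dataset.from_dict({"text": ["hello", "world"]})

with PartialState().local_main_process_first():
    # The local main process runs the map first and populates the cache;
    # the remaining local processes enter afterwards and load the cached
    # result rather than redoing the preprocessing.
    dataset = dataset.map(lambda x: {"n_chars": len(x["text"])})
```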
47 changes: 47 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -21,6 +21,7 @@
import pandas as pd
import torch
import torch.nn as nn
from accelerate import PartialState
from accelerate.utils import gather_object
from datasets import Dataset
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
@@ -43,6 +44,25 @@
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]:
"""Tokenize a batch from a reward modelling dataset."""
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

return new_examples


class RewardTrainer(Trainer):
r"""
The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
@@ -205,6 +225,33 @@ def __init__(
self.use_reward_data_collator = True
else:
self.use_reward_data_collator = False

if "input_ids" not in train_dataset.column_names:
with PartialState().local_main_process_first():
fn_kwargs = {"tokenizer": tokenizer}
train_dataset = train_dataset.map(
_tokenize,
batched=True,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
_tokenize,
fn_kwargs=fn_kwargs,
batched=True,
num_proc=args.dataset_num_proc,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length
and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)

super().__init__(
model=model,
args=args,
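As a rough standalone illustration of the path added above (not part of the PR): tokenize a toy, already-templated preference dataset with the new helper and apply the same length filter. This assumes the private `_tokenize` helper stays importable from `trl.trainer.reward_trainer`; the tokenizer and `max_length` value are placeholders.

```python
from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer.reward_trainer import _tokenize  # private helper added in this PR

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
max_length = 512  # stands in for RewardConfig.max_length

# Toy preference pairs that have already been through the chat template,
# i.e. plain strings rather than message lists.
dataset = Dataset.from_dict(
    {
        "chosen": ["The sky is blue.", "2 + 2 = 4"],
        "rejected": ["The sky is green.", "2 + 2 = 5"],
    }
)

# The same two steps the trainer now runs in __init__ when "input_ids" is missing.
dataset = dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= max_length
    and len(x["input_ids_rejected"]) <= max_length
)
print(dataset.column_names)
```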