[RewardTrainer] Tokenize inputs within trainer #2102

Merged · 15 commits · Sep 24, 2024
43 changes: 2 additions & 41 deletions README.md
@@ -103,15 +103,12 @@ from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

# configure trainer
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)

# train
trainer.train()
```

@@ -120,7 +117,7 @@ trainer.train()
Here is a basic example of how to use the `RewardTrainer`:

```python
from trl import RewardConfig, RewardTrainer
from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
from trl.extras.dataset_formatting import conversations_formatting_function
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@@ -131,34 +128,7 @@ model = AutoModelForSequenceClassification.from_pretrained(
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")

def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)
        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(
            tokenized_rejected["attention_mask"]
        )

    return new_examples

chosen_fn = conversations_formatting_function(tokenizer, "chosen")
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)})
dataset = dataset.map(
    preprocess_function,
    batched=True,
)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(
    per_device_train_batch_size=2,
@@ -171,8 +141,6 @@ trainer = RewardTrainer(
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
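With this PR, the manual tokenization shown in the removed lines above moves into the `RewardTrainer` itself, so the README example no longer needs a `preprocess_function`. As a rough sketch of the per-pair tokenization the trainer is now expected to perform internally (the function name below is illustrative, mirroring the removed code rather than TRL's actual internals), assuming `chosen` and `rejected` are already plain strings:

```python
# Illustrative sketch only: mirrors the removed preprocess_function above,
# approximating what the trainer now does for each preference pair.
def tokenize_preference_pair(example, tokenizer):
    tokenized_chosen = tokenizer(example["chosen"])
    tokenized_rejected = tokenizer(example["rejected"])
    return {
        "input_ids_chosen": tokenized_chosen["input_ids"],
        "attention_mask_chosen": tokenized_chosen["attention_mask"],
        "input_ids_rejected": tokenized_rejected["input_ids"],
        "attention_mask_rejected": tokenized_rejected["attention_mask"],
    }
```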

@@ -210,7 +178,6 @@ trainer = RLOOTrainer(
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
# train
trainer.train()
```

@@ -219,30 +186,24 @@ trainer.train()
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
# imports
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# load preference dataset - needs to be in a specific format
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

# load trainer
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()
```
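The two `maybe_*` helpers in this example prepare a conversational preference dataset before training: `maybe_extract_prompt` splits the shared prompt out of the `chosen`/`rejected` message lists, and `maybe_apply_chat_template` renders the remaining messages into plain strings using the tokenizer's chat template. The snippet below is a rough, single-example illustration of that flow; the exact rendered strings depend on the model's chat template, so treat the printed output as indicative rather than the helpers' documented behavior.

```python
from transformers import AutoTokenizer
from trl import maybe_extract_prompt, maybe_apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# A single hand-written conversational preference example (illustrative data).
example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}

# Split the shared user turn into a separate "prompt" field.
example = maybe_extract_prompt(example)

# Render the message lists into strings with the model's chat template.
example = maybe_apply_chat_template(example, tokenizer=tokenizer)

print(example["prompt"])    # templated shared prompt
print(example["chosen"])    # templated preferred completion
print(example["rejected"])  # templated dispreferred completion
```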

61 changes: 7 additions & 54 deletions examples/scripts/reward_modeling.py
@@ -19,8 +19,6 @@
--output_dir Qwen2-0.5B-Reward \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--gradient_accumulation_steps 1 \
--remove_unused_columns False \
--gradient_checkpointing True \
--learning_rate 1.0e-5 \
--logging_steps 25 \
Expand All @@ -32,17 +30,15 @@
python examples/scripts/reward_modeling.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--output_dir Qwen2-0.5B-Reward \
--output_dir Qwen2-0.5B-Reward-LoRA \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--gradient_accumulation_steps 1 \
--remove_unused_columns False \
--gradient_checkpointing True \
--learning_rate 1.0e-5 \
--learning_rate 1.0e-4 \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--max_length 2048 /
--max_length 2048 \
--use_peft \
--lora_r 32 \
--lora_alpha 16
@@ -51,9 +47,7 @@
import warnings

import torch
from accelerate import PartialState
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser

from trl import (
@@ -66,10 +60,6 @@
    setup_chat_format,
)
from trl.commands.cli_utils import RewardScriptArguments
from trl.extras.dataset_formatting import conversations_formatting_function


tqdm.pandas()


if __name__ == "__main__":
@@ -90,6 +80,7 @@
        revision=model_config.model_revision,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
@@ -110,49 +101,11 @@
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
)

    #############################
    # Load and preprocess dataset
    #############################
    ##############
    # Load dataset
    ##############
    dataset = load_dataset(args.dataset_name)

    def preprocess_function(examples):
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
            tokenized_chosen = tokenizer(chosen)
            tokenized_rejected = tokenizer(rejected)
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

        return new_examples

    with PartialState().local_main_process_first():
        # Wrap inputs with chat template.
        # This assumes the chosen/rejected columns are in the OpenAI messages format.
        chosen_fn = conversations_formatting_function(tokenizer, "chosen")
        rejected_fn = conversations_formatting_function(tokenizer, "rejected")
        dataset = dataset.map(
            lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
        )
        # Tokenize inputs
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=training_args.dataset_num_proc,
        )
        # Filter out examples that are too long
        dataset = dataset.filter(
            lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
            and len(x["input_ids_rejected"]) <= training_args.max_length,
            num_proc=training_args.dataset_num_proc,
        )

    ##########
    # Training
    ##########