diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index d9f75973ef..148d9c9571 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -199,21 +199,21 @@ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", " Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer. -| Trainer | Expected dataset type | -| ----------------------- | ------------------------------------------------------- | -| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | -| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) | -| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | -| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | -| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`PPOTrainer`] | Tokenized language modeling | -| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | -| [`SFTTrainer`] | [Language modeling](#language-modeling) | -| [`XPOTrainer`] | [Prompt-only](#prompt-only) | +| Trainer | Expected dataset type | +| ----------------------- | ------------------------------------------------------------------------------------------------------ | +| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | +| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`PPOTrainer`] | Tokenized language modeling | +| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | +| [`SFTTrainer`] | [Language modeling](#language-modeling) | +| [`XPOTrainer`] | [Prompt-only](#prompt-only) | diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx index 7c6433be43..dc881f9577 100644 --- a/docs/source/kto_trainer.mdx +++ b/docs/source/kto_trainer.mdx @@ -2,109 +2,133 @@ [![](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl) -TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://huggingface.co/papers/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela. -For a full example have a look at [`examples/scripts/kto.py`]. +## Overview -Depending on how good your base model is, you may or may not need to do SFT before KTO. -This is different from standard RLHF and DPO, which always require SFT. 
-You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below). +Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela). -## Expected dataset type -The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns: +The abstract from the paper is the following: -- `prompt` -- `completion` -- `label` +> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive. -for example: +The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs). +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente. + +## Quick start + +This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). 
You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_kto.py +from datasets import load_dataset +from trl import KTOConfig, KTOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train") + +training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10) +trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() ``` -kto_dataset_dict = { - "prompt": [ - "Hey, hello", - "How are you", - "What is your name?", - "What is your name?", - "Which is the best programming language?", - "Which is the best programming language?", - "Which is the best programming language?", - ], - "completion": [ - "hi nice to meet you", - "leave me alone", - "I don't have a name", - "My name is Mary", - "Python", - "C++", - "Java", - ], - "label": [ - True, - False, - False, - True, - True, - False, - False, - ], -} + +Execute the script using the following command: + +```bash +accelerate launch train_kto.py ``` -where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`). -A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. -In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate). +Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png) -## Expected model format -The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface). -## Using the `KTOTrainer` +
$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
+<quentin_gallouedec>:
+What is the best programming language?
 
-For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response. 
+<trl-lib/Qwen2-0.5B-KTO>:
+The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
+
-The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
+Here are some other factors to consider when choosing a programming language for a project:
 
-The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
-By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
+ 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.                                                                   
+ 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.                                                                                                                                                            
+ 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.                                                                                                                                         
+ 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.   
+
- -Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1', this learning rate is `1e-6` for most models. The lower the beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs. - +## Expected dataset format - -Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. - - -```py -training_args = KTOConfig( - beta=0.1, - desirable_weight=1.0, - undesirable_weight=1.0, - learning_rate=5e-7, -) - -kto_trainer = KTOTrainer( - model, - ref_model, - args=training_args, - train_dataset=train_dataset, - processing_class=tokenizer, -) -``` -After this one can then call: +KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. + +The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. + +## Example script -```py -kto_trainer.train() +We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) + +To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: + +```bash +accelerate launch examples/scripts/kto.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/kto-mix-14k \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-KTO ``` +## Usage tips + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. -To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + + +### Batch size recommendations + +Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. 
Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. + +### Learning rate recommendations + +Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results. + +### Imbalanced data + +The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples. +By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. + +## Logged metrics + +While training and evaluating we record the following reward metrics: -This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig). -To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001). +- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +- `logps/chosen`: the mean log probabilities of the chosen completions +- `logps/rejected`: the mean log probabilities of the rejected completions +- `logits/chosen`: the mean logits of the chosen completions +- `logits/rejected`: the mean logits of the rejected completions +- `kl`: the KL divergence between the policy model and the reference model ## KTOTrainer diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md index 628383c5bd..02d0b9c86b 100644 --- a/docs/source/orpo_trainer.md +++ b/docs/source/orpo_trainer.md @@ -4,7 +4,7 @@ ## Overview -Odds Ratio Preference Optimization (ORPO) wa introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). +Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). 
The abstract from the paper is the following: @@ -95,7 +95,7 @@ accelerate launch examples/scripts/orpo.py \ --dataset_name trl-lib/ultrafeedback_binarized \ --num_train_epochs 1 \ --logging_steps 25 \ - --output_dir Qwen2-0.5B-DPO + --output_dir Qwen2-0.5B-ORPO ``` ## Usage tips diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 84d56ac379..50dbcd5f36 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -55,7 +55,6 @@ --lora_alpha=16 """ -from accelerate import PartialState from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser @@ -65,7 +64,6 @@ ModelConfig, ScriptArguments, get_peft_config, - maybe_unpair_preference_dataset, setup_chat_format, ) @@ -95,24 +93,6 @@ # Load the dataset dataset = load_dataset(script_args.dataset_name) - # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label) - dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc) - - # Apply chat template - def format_dataset(example): - if isinstance(example["completion"], str): - example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False) - example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False) - else: - example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False) - example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False) - return example - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().local_main_process_first(): - dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc) - # Initialize the KTO trainer trainer = KTOTrainer( model, diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 6154aeeecf..d5a094a66e 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -42,17 +42,17 @@ def setUp(self): @parameterized.expand( [ - ["gpt2", "kto", True, True], - ["gpt2", "kto", True, False], - ["gpt2", "kto", False, True], - ["gpt2", "kto", False, False], - ["gpt2", "apo_zero_unpaired", True, True], - ["gpt2", "apo_zero_unpaired", True, False], - ["gpt2", "apo_zero_unpaired", False, True], - ["gpt2", "apo_zero_unpaired", False, False], + ("gpt2", "standard_preference", "kto", True, True), + # ("t5", "standard_implicit_prompt_preference", "kto", True, False), # KTO broken for enc-dec + ("gpt2", "standard_unpaired_preference", "kto", False, True), + # ("t5", "conversational_preference", "kto", False, False), + ("gpt2", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False), + ("gpt2", "standard_unpaired_preference", "apo_zero_unpaired", False, True), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False), ] ) - def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset): + def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset): with tempfile.TemporaryDirectory() as tmp_dir: training_args = KTOConfig( output_dir=tmp_dir, @@ -68,7 +68,7 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset): report_to="none", ) - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + dummy_dataset = load_dataset("trl-internal-testing/zen", 
config_name) if name == "gpt2": model = self.model diff --git a/trl/data_utils.py b/trl/data_utils.py index 569d398b52..146466bd6b 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -199,7 +199,9 @@ def _unpair_row(examples: List[Dict[str, List[Dict[str, str]]]]) -> List[Dict[st return new_rows -def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = None) -> DatasetType: +def unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: r""" Unpair a preference dataset. @@ -209,6 +211,8 @@ def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = No `"prompt"`. num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. Returns: `Dataset`: The unpaired preference dataset. @@ -233,10 +237,12 @@ def unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = No {'prompt': 'The sky is', 'completion': ' blue.', 'label': True} ``` """ - return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc) + return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc, desc=desc) -def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int] = None) -> DatasetType: +def maybe_unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: r""" Unpair a preference dataset if it is paired. @@ -246,6 +252,8 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int `"prompt"`. num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. Returns: `Dataset` or `DatasetDict`: The unpaired preference dataset if it was paired, otherwise the original dataset. @@ -275,7 +283,7 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int else: column_names = dataset.column_names if "chosen" in column_names and "rejected" in column_names: - return unpair_preference_dataset(dataset, num_proc=num_proc) + return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc) else: return dataset @@ -380,6 +388,8 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]: # "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], # "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]} # That's why we check if the prompt is also conversational before deciding not to extract it. 
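+    # KTO-style unpaired examples only carry "prompt"/"completion"/"label" columns;
+    # without "chosen"/"rejected" there is no prompt to extract, so return the example unchanged.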
+ if "chosen" not in example or "rejected" not in example: # not a preference example + return example if "prompt" in example: # Both conversational or both non-conversational chosen_conv = is_conversational({"chosen": example["chosen"]}) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 7f32424812..fa1541bb44 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -48,6 +48,7 @@ from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig from .utils import ( @@ -566,11 +567,37 @@ def make_inputs_require_grad(module, input, output): " meaning the auxiliary loss will not be used." ) + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): - # Shuffle the datasets - train_dataset = train_dataset.shuffle(seed=args.data_seed) + # Extract the prompt if needed + train_dataset = train_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + # Unpair the dataset if needed + train_dataset = maybe_unpair_preference_dataset( + train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" + ) + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to train dataset", + ) if eval_dataset is not None: - eval_dataset = eval_dataset.shuffle(seed=args.data_seed) + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) + eval_dataset = maybe_unpair_preference_dataset( + eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" + ) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to eval dataset", + ) # Tokenize and prepare the training datasets train_dataset = train_dataset.map(
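For reference, below is a minimal sketch (not part of the diff) of the dataset preparation that the updated `KTOTrainer` now performs internally, built from the `maybe_extract_prompt` and `maybe_unpair_preference_dataset` helpers shown above; the toy rows and the `desc` label are illustrative assumptions.

```python
# Minimal sketch: how a paired preference dataset becomes the unpaired
# (prompt, completion, label) format that KTO trains on.
from datasets import Dataset
from trl.data_utils import maybe_extract_prompt, maybe_unpair_preference_dataset

# Toy paired preference rows (prompts already explicit)
dataset = Dataset.from_dict(
    {
        "prompt": ["The sky is", "The sun is"],
        "chosen": [" blue.", " in the sky."],
        "rejected": [" green.", " in the sea."],
    }
)

dataset = dataset.map(maybe_extract_prompt)  # no-op here: "prompt" is already present
dataset = maybe_unpair_preference_dataset(dataset, desc="Unpairing dataset")
print(dataset[0])  # {'prompt': 'The sky is', 'completion': ' blue.', 'label': True}
```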