Releases: huggingface/trl
v0.15.0
Major and breaking changes
Coming soon
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #2689
- 📦 `trl.templates` in excluded packages by @qgallouedec in #2690
- 📖 Docs fix spelling issues by @nnsW3 in #2682
- 📄 Add GRPO batch size note in docs by @sdpkjc in #2672
- 🙈 Fixed typo in the GRPO documentation by @famouswizard in #2691
- docs: Fix broken "Good First Issue" link in CONTRIBUTING.md by @famouswizard in #2693
- 🧠 Fix typo in "understand" in ppo_trainer.md by @famouswizard in #2695
- ☠️ Remove deprecated by @qgallouedec in #2692
- 💡 Add "Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial" by @qgallouedec in #2697
- 📋 Add eval loss logging during prediction in GRPO by @kashif in #2694
- fix: Fix typo in filename Update ultrafeedback.py by @brawncode in #2699
- 📖 Add GRPOTrainer to README.md by @burtenshaw in #2713
- Improve GRPO example by @lewtun in #2717
- 📖 Nit Fix in Documentation by @ParagEkbote in #2722
- 🏰 `num_logits_to_keep` to `logits_to_keep` by @qgallouedec in #2721
- 💰 Fix incorrect calculation in Olivia's baguette spending logic by @defiberrys in #2727
- fix: Fix typo in filename in ultrafeedback-prompt.py by @brawncode in #2716
- docs: Fix typos in alias descriptions by @defiberrys in #2729
- ⚠️ Fix Attention Masking in GRPO by @andyl98 in #2708
- 🔂 Use vLLM prefix caching for speedup by @winglian in #2757
- 💔 Decouple loss computing and generation in GRPO by @qgallouedec in #2762
- 📌 vLLM >= 0.7.1 for device fix by @ctjlewis in #2766
- 📐 Add vLLM dtype configuration for GRPO trainer by @joey00072 in #2738
- 📖 Clarification max len in Reward documentation by @ParagEkbote in #2740
- 🔎 Add missing script argument in PPO documentation by @JohnConnor123 in #2720
- 🤖 Properly unwrap torch.compile-ed models in GRPO by @winglian in #2750
- 🔁 🦈 Support iterative GRPO by @shirinyamani in #2700
- 🚧 Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation by @SeungyounShin in #2667
- ↔️ GRPO: Set max_model_len when initializing vLLM instance by @mirceapricop in #2728
- 💡 GRPO vram-efficiency improvement; only compute relevant logprobs by @tyler-romero in #2773
- 🙃 Fix reward function in GRPO example by @junuMoon in #2777
- 💡 Add 'Post training an LLM for reasoning with GRPO in TRL' tutorial by @sergiopaniego in #2785
- 📉 Optimize GRPO memory usage by redefining `per_device_batch_size` as generations per device by @qgallouedec in #2776
- 🆚 Distinguish padding and eos when they differ by @binary-husky in #2793
- 🎯 [SFT] add token accuracy metric by @kashif in #2597
- 📠 Log completions for GRPO by @qgallouedec in #2772
- 🔬 SFT simplification by @qgallouedec in #2405
- ➖ Fix GRPO example in README by @qgallouedec in #2800
- ⛰️ Reduce peak vram consumption with efficient selective log_softmax by @tyler-romero in #2799
- fix: typos in documentation files by @maximevtush in #2804
- 📤 GRPO refactor loading the model weights to vllm by @winglian in #2817
- 🫘 Add `set_seed()` call in GRPO to ensure unique seed for each process by @qgallouedec in #2824
- ⚖️ Add reward weight in multi-reward settings for GRPO by @hesamsheikh in #2676
- 🙌 Share vLLM device with training when only 1 available by @qgallouedec in #2827
- 👴 Update `tokenizer` parameter to `processing_class` in tests by @qgallouedec in #2828
- 🥾 Allow bootstrap GRPO by @qgallouedec in #2829
- ⚡ Fix GRPO PEFT by @qgallouedec in #2725
- Fix PeftModel check when moving weights to vLLM by @edbeeching in #2850
- 🪆 Fix for Incorrect ValueError Handling in reward_weights in grpo_trainer.py by @loveychen in #2843
- 👨👩👧 GRPO + PEFT + vLLM by @winglian in #2818
New Contributors
- @nnsW3 made their first contribution in #2682
- @sdpkjc made their first contribution in #2672
- @famouswizard made their first contribution in #2691
- @brawncode made their first contribution in #2699
- @ParagEkbote made their first contribution in #2722
- @defiberrys made their first contribution in #2727
- @ctjlewis made their first contribution in #2766
- @joey00072 made their first contribution in #2738
- @JohnConnor123 made their first contribution in #2720
- @shirinyamani made their first contribution in #2700
- @mirceapricop made their first contribution in #2728
- @tyler-romero made their first contribution in #2773
- @junuMoon made their first contribution in #2777
- @binary-husky made their first contribution in #2793
- @maximevtush made their first contribution in #2804
- @hesamsheikh made their first contribution in #2676
- @loveychen made their first contribution in #2843
Full Changelog: v0.14.0...v0.15.0
v0.14.0
Major and breaking changes
👨👨👧👧 GRPO
by @qgallouedec in #2565
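A minimal usage sketch (the model, dataset, and reward function below are illustrative; GRPO accepts custom reward functions that receive the generated completions):
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward function, for illustration only: favors completions close to 20 characters.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

train_dataset = load_dataset("trl-lib/tldr", split="train")
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()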
What's Changed
- ⚰️ Remove deprecated by @qgallouedec in #2485
- 🗣️ Improve prose for smol course by @burtenshaw in #2487
- 🤩 Add SmolVLM tutorials to Community Tutorials page by @sergiopaniego in #2498
- 🏞️ Proper dataset for documentation images by @qgallouedec in #2499
- 🗂️ Reorganize documentation by @qgallouedec in #2483
- [ORPO] fix orpo chosen-nll loss by @kashif in #2502
- 🏚 Remove unused components by @qgallouedec in #2480
- Update community_tutorials.md by @qgallouedec in #2509
- ❎ Remove RLOO example test by @qgallouedec in #2513
- 👨🍳 Clarify DPO data preparation by @qgallouedec in #2512
- 💧 Generalize `disable_dropout` by @qgallouedec in #2511
- 👬 Rename collator `PreferenceCollator` to `DataCollatorForPreference` by @qgallouedec in #2510
- 📦 Packing documentation by @qgallouedec in #2503
- ☄️ Update Comet integration to include LogCompletionsCallback and Trainer.evaluation_loop() by @yaricom in #2501
- Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM by @Abhishek-TAMU in #2158
- 🚜 Use field in dataclasses by @qgallouedec in #2494
- ©️ Update copyrights year by @qgallouedec in #2547
- 🧑🤝🧑 Proper metrics gathering across ranks before logging by @zhc7 in #2474
- ✒️ Fix typo in `formatting_func`'s documentation in `ConstantLengthDataset` by @SamuelLarkin in #2549
- 🕊️ DPO padding free by @qgallouedec in #2520
- ℹ️ XPU support for DPO by @faaany in #2533
- 🔠 Fix SFT truncation documentation by @umbilnm in #2521
- ↩️ Revert ORPO loss changes by @kashif in #2527
- 🎴 Add readme for datasets by @August-murr in #2491
- 💔 Fix dataset type unpair conversion docs by @claralp in #2550
- [RLOO] Reinforce++ by @kashif in #2552
- 🏛️ Improve DPO configuration documentation structure by @qgallouedec in #2561
- ✨ Refine model card method docstring by @qgallouedec in #2566
- 🪄 Minor comment style modif by @qgallouedec in #2582
- 🏎️ vllm for Online DPO by @qgallouedec in #2558
- 🔖 Issues Auto-Labeller by @August-murr in #2542
- 🐛 Simplify bug report template by @qgallouedec in #2585
- [RLOO] fix token_level_kl by @kashif in #2575
- ✂️ Truncate by default by @qgallouedec in #2587
- 🫢 Add `max_prompt_length` parameter in tests by @qgallouedec in #2588
- 🎞️ Fix documentation SFT - `max_seq_length` instead of `max_length` by @skandermoalla in #2590
- 👨👨👧👧 GRPO by @qgallouedec in #2565
- 🫣 Ignore CLI test for Python 3.9 by @qgallouedec in #2592
- Fix merge error by @qgallouedec in #2595
- 🧰 Tool fine-tuning support DPO by @August-murr in #2479
- 💾 Reduce memory peak in GRPO by adding `max_prompt_length` and loop usage in logp computation by @qgallouedec in #2598
- ⚡ Add uv installation instructions by @stevhliu in #2601
- 🧩 PPO/RLOO/OnlineDPO sequence generation: make deepspeed 3 weight gathering optional by @dawidm in #2557
- 🫷 Include stop token in policy model's generation_config by @dawidm in #2528
- ✂️ Reintroduce `truncation_mode` in `DPOTrainer` by @anakin87 in #2551
- 👋 Drop MDX by @qgallouedec in #2611
- 💎 Rename an inner var in GRPO to improve clarity by @qgallouedec in #2616
- 🏆 Custom reward function for GRPO and shiny doc by @qgallouedec in #2606
- 🥞 Fix DPO gradient accumulation loss scaling by @winglian in #2615
- 🥞 Fix BCO gradient accumulation loss scaling by @qgallouedec in #2638
- 🍭 Custom reward function for RLOO by @August-murr in #2612
- 🌯 Fix context manager runtime error when gather is disabled by @Superskyyy in #2639
- 🥞 Fix CPO gradient accumulation loss scaling by @qgallouedec in #2645
- 🥞 Fix GRPO gradient accumulation loss scaling by @qgallouedec in #2647
- 🥞 Fix KTO gradient accumulation loss scaling by @qgallouedec in #2648
- 🚛 Provide all columns of the dataset to the reward function by @qgallouedec in #2650
- 👐 DeepSpeed integration for GRPO by @qgallouedec in #2652
- 🔎 Finegrained reward logging for GRPO by @qgallouedec in #2651
- 📍 Disable caching when grad checkpointing enable in GRPO by @qgallouedec in #2653
- 📏 Log completion length in GRPO by @qgallouedec in #2659
- 🌀 Fix GRPO default completion length doc by @andyl98 in #2662
- 🏷️ Add model tags to model trained with GRPO by @qgallouedec in #2663
- 🖊 Fix typos by @omahs in #2673
- ⚡ vLLM for fast generation in GRPO by @qgallouedec in #2600
- 📉 Use `num_logits_to_keep` to reduce memory usage in GRPO by @qgallouedec in #2683
New Contributors
- @Abhishek-TAMU made their first contribution in #2158
- @zhc7 made their first contribution in #2474
- @SamuelLarkin made their first contribution in #2549
- @umbilnm made their first contribution in #2521
- @stevhliu made their first contribution in #2601
- @dawidm made their first contribution in #2557
- @Superskyyy made their first contribution in #2639
- @andyl98 made their first contribution in #2662
- @omahs made their first contribution in #2673
Full Changelog: v0.13.0...v0.14.0
v0.13.0
Major and breaking changes
🐾 Process-supervised RM Trainer
We introduced a new trainer to train Process-supervised Reward Models (PRMs) in TRL. A PRM rewards the quality of intermediate steps, promoting structured reasoning over focusing solely on the final outcome. With this trainer, we introduce a new dataset type: stepwise supervision, a variant of the prompt-completion type in which the completion is divided into several intermediate steps, each associated with a label. Find out more in the stepwise-supervision section of the TRL documentation.
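For illustration, a single stepwise-supervision example could look like the following (the content is made up; the keys follow the documented dataset type):
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8 and that of 9.11 is 0.11.",
        "Since 0.11 > 0.8, 9.11 is larger.",
    ],
    "labels": [True, False],
}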
Here is an example of how to use the `PRMTrainer` to train a PRM on the Math Shepherd dataset:
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
For more information, check out the PRMTrainer documentation.
by @qgallouedec and @gaetanlop in #2127 and #2148
🔀 Add MergeModelCallback
Various works show that model merging can non-trivially improve performance, especially if the models belong to the same architecture. TRL now features a callback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. This can be done at the end of each step or epoch, and/or at the end of training. This callback uses Arcee's mergekit lib: https://github.com/arcee-ai/mergekit
from trl import DPOTrainer, MergeModelCallback
from trl.mergekit_utils import MergeConfig
config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
by @August-murr in #2282
🔨 Support for tools for data utils
TRL preprocessing utils now support tools, a first step toward agent fine-tuning.
from trl import apply_chat_template
def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0
example = apply_chat_template(example, tokenizer, tools=[get_current_temperature])
by @August-murr in #2455
🌋 Add support for LLaVA-Next in DPOTrainer
VLMs have their own specificities which require special treatment in the trainer. `DPOTrainer` now supports LLaVA-Next models natively.
model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
trainer = DPOTrainer(model=model, ...)
by @chenweize1998 in #2413
🕹️ CLI and TRLParser refactor
The TRL CLI has been refactored to be more user-friendly and easier to extend. We plan to extend the support to all trainers soon.
(simplified output, for readability)
$ trl dpo --help
usage: trl dpo [-h] --dataset_name DATASET_NAME [--dataset_config DATASET_CONFIG] --output_dir OUTPUT_DIR [--loss_type {sigmoid,hinge,ipo}]
options:
-h, --help show this help message and exit
--dataset_name DATASET_NAME, --dataset-name DATASET_NAME
--dataset_config DATASET_CONFIG, --dataset-config DATASET_CONFIG
--output_dir OUTPUT_DIR, --output-dir OUTPUT_DIR
The output directory where the model predictions and checkpoints will be written. (default: None)
--loss_type {sigmoid,hinge,ipo}, --loss-type {sigmoid,hinge,ipo}
by @qgallouedec in #2380 and #2412
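For example, a training run might be launched like this (model and dataset names are illustrative, and `--model_name_or_path` assumes the standard model arguments exposed by the subcommand):
$ trl dpo \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --output_dir Qwen2-0.5B-DPO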
🤝 Mixture of judges
TRL features a new judge, `AllTrueJudge`, that unifies the decisions of multiple binary judges. This judge implements the Mixture of Judges described in the CGPO paper.
import random

from trl import AllTrueJudge, BaseBinaryJudge

class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        return [random.choice([0, 1, -1]) for _ in range(len(prompts))]

prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
judgements = judge.judge(prompts=prompts, completions=completions)
print(judgements)  # [0, 1]
by @gaetanlop in #2159
❄️ DPO trainer supports `num_logits_to_keep` to save memory
Save memory by only keeping the top `num_logits_to_keep` logits in the DPO trainer.
training_args = DPOConfig(..., use_num_logits_to_keep=True)
🗺️ Implementation of the DiscoPOP Loss
The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0).
training_args = DPOConfig(..., loss_type="discopop", discopop_tau=0.05)
🧑🍳 Add precompute batch size argument in `DPOTrainer` for reference model
We can now control the batch size for precomputing reference model logits.
training_args = DPOConfig(
...
precompute_ref_log_probs=True,
precompute_ref_batch_size=4,
)
by @SwayamInSync in #2426
📦 Support for packing tokenized datasets for SFT
`SFTTrainer` has supported packing datasets for faster training. Now, it supports packing tokenized datasets as well.
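A minimal sketch, assuming a dataset that is already tokenized (i.e. contains an "input_ids" column); the model and dataset names below are placeholders:
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Hypothetical pre-tokenized dataset with an "input_ids" column.
train_dataset = load_dataset("my-org/my-tokenized-dataset", split="train")

training_args = SFTConfig(output_dir="Qwen2-0.5B-SFT", packing=True)
trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()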
📉 Add PEFT support for PPOTrainer
`PPOTrainer` now supports PEFT for efficient training.
PPOTrainer(
...,
peft_config=peft_config,
)
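For reference, the `peft_config` above is a standard PEFT configuration; a minimal sketch (hyperparameters are illustrative):
from peft import LoraConfig

# Illustrative LoRA configuration to pass as `peft_config`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)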
💾 Deprecate `config` in favor of `args` in `PPOTrainer`
`config` has been deprecated in favor of `args` in `PPOTrainer`.
PPOTrainer(
- config=training_args,
+ args=training_args,
)
by @qgallouedec in #2384
👮 Deprecate `policy` in favor of `model` in `PPOTrainer`
`policy` has been deprecated in favor of `model` in `PPOTrainer`.
PPOTrainer(
- policy=model,
+ model=model,
)
by @qgallouedec in #2386
What's Changed
- ⏫ Bump dev version to `0.13.0.dev0` by @qgallouedec in #2305
- 📰 Update blog posts in documentation by @qgallouedec in #2319
- ⚰️ Remove deprecated args, script arguments, and PPOv2 by @qgallouedec in #2306
- 🧽 Fix judge doc by @qgallouedec in #2320
- 🪧 Fix slack notification titles by @qgallouedec in #2322
- 🪪 Check with `token_id` instead of `token` in `DPOTrainer` by @qgallouedec in #2324
- Fix wrong truncating index of tensor in DPOTrainer's concatenated_forward() by @yanghh2000 in #2332
- Fix gradient_checkpointing_kwargs assignment in examples by @Galaxy-Husky in #2331
- Bump liger-kernel to 0.4.0 by @ByronHsu in #2333
- DPO trainer supports num_logits_to_keep to save memory by @xyangk in #2129
- 🧞 Add `output_layer` to the list of `lm_head_namings` in `AutoModelForCausalLMWithValueHead` by @qgallouedec in #2328
- 🫴 Better guide users in error reporting by @qgallouedec in #2327
- 🪡 Various RLOO fixes by @qgallouedec in #2325
- 💣 Remove transformers version check by @xyangk in #2343
- 👈 Add `tokenizer` arg back and add deprecation guidelines by @qgallouedec in #2348
- 🖨️ Fix error text in BCO and KTO tokenizing function by @PhilipMay in #2286
- Adding video llm fine-tuning example by @mfarre in #2336
- 👋 Remove deprecated `tokenizer` argument in BCO, GKD, Iterative SFT, Nash MD and XPO by @qgallouedec in #2349
- ⚖️ Add `use_soft_judge` option to `WinRateCallback` by @kashif in #2347
- 🪜 Stepwise supervision dataset type by @qgallouedec in #2148
- 🔮 Inference mode in `GeometricMixtureWrapper.forward` by @kashif in #2345
- 🗃️ Use specified `data_collator` in `RLOOTrainer` and `PPOTrainer` by @bartoszzuk in h...
v0.12.2
v0.12.1
What's Changed
- 👈 Add `tokenizer` arg back and add deprecation guidelines by @qgallouedec in #2348
Full Changelog: v0.12.0...v0.12.1
v0.12.0
Major and breaking changes
General reward model support for Online DPO
Online DPO initially only supported a reward model that had the same tokenizer and chat template as the trained model. Now, you can use any reward model.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name, num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name, truncation=True, truncation_side="left")
dataset = load_dataset(dataset_name)
training_args = OnlineDPOConfig(output_dir="...")
trainer = OnlineDPOTrainer(
model=model,
reward_model=reward_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_processing_class=reward_tokenizer,
)
trainer.train()
by @qgallouedec in #2276
Migration `PPOv2` -> `PPO`
The `PPOv2` trainer has been renamed to `PPO`. The old `PPO` trainer has been removed. `PPOv2` is now deprecated and will be removed in the next release.
- trainer = PPOv2Trainer(...)
+ trainer = PPOTrainer(...)
by @qgallouedec in #2174
Refactor `ScriptArguments`
We had `ScriptArguments`, `SFTScriptArguments`, `DPOScriptArguments` and `RewardScriptArguments`. Since they all share mostly the same fields, we've merged them into a single `ScriptArguments` class. `SFTScriptArguments`, `DPOScriptArguments` and `RewardScriptArguments` still exist but are deprecated and will be removed in the next release.
- script_args = DPOScriptArguments(...)
+ script_args = ScriptArguments(...)
by @qgallouedec in #2145
Soft judges for PairRM
The `PairRMJudge`, when called via the `judge` method, now has a `return_scores` flag that returns the probability scores of the first completion of each pair (instead of the rank of the preferred completion). The logits for the probability score can be scaled by an optional `temperature` parameter.
from trl import PairRMJudge
pairrm_judge = PairRMJudge()
prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
results = pairrm_judge.judge(prompts, completions, return_scores=True)
print(results) # [0.7492601275444031, 0.0005497377132996917]
Use pairwise judges for online methods
The `OnlineDPOTrainer` and any trainers that inherit from it (`NashMDTrainer` and `XPOTrainer`) can now accept an initialized `PairwiseJudge` instead of a reward model.
from datasets import load_dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
trainer.train()
Rename trainer arg `tokenizer` to `processing_class`
The `tokenizer` argument in the trainers has been renamed to `processing_class` to better reflect the fact that it can be not only a tokenizer but also a processor.
- trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, tokenizer=tokenizer)
+ trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
`tokenizer` is still supported for `SFTTrainer` and `DPOTrainer` but deprecated and will be removed in the next release.
by @qgallouedec in #2162
Adding weighted preference optimization (WPO) to DPO
The WPO paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the `DPOConfig`.
DPOConfig(..., use_weighting=True)
by @gaetanlop in #2141
🃏 Model card for TRL
Using `trainer.push_to_hub()` now automatically creates a model card that includes:
- A link to the base model used
- A link to the dataset used for training
- A link to the TRL repository
- Sample demo code
- A link to the associated Weights & Biases run
- A link to the paper detailing the training procedure
- Versions of dependencies
- BibTeX citations for both the training procedure and TRL
All links are properly formatted to allow cross-referencing, enabling traceability back to sources (e.g., the model appears linked on the paper’s page).
by @qgallouedec in #2123
Minor
Conversational dataset support
You can now use conversational datasets directly, without needing to apply a chat template beforehand, for the following trainers:
- `BCOTrainer` (by @qgallouedec in PR #2107)
- `CPOTrainer` (by @qgallouedec in PR #2144)
- `DPOTrainer` (by @qgallouedec in PR #2131)
- `KTOTrainer` (by @qgallouedec in PR #2248)
- `ORPOTrainer` (by @qgallouedec in PR #2184)
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset(dataset_name, split="train")
# Not needed anymore:
#
# def process(example):
# prompt = tokenizer.apply_chat_template(example["prompt"], tokenize=False, add_generation_prompt=True)
# prompt_chosen = tokenizer.apply_chat_template(example["prompt"] + example["chosen"], tokenize=False)
# chosen = prompt_chosen[len(prompt) :]
# prompt_rejected = tokenizer.apply_chat_template(example["prompt"] + example["rejected"], tokenize=False)
# rejected = prompt_rejected[len(prompt) :]
# return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
#
# dataset = dataset.map(process)
training_args = DPOConfig(output_dir="...")
trainer = DPOTrainer(model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
Refactor DPO data processing
For more information, see PR #2209.
`trl env` for printing system info
You can now use `trl env` to print system information, including the platform, Python version, PyTorch version, CUDA device(s), and versions of various libraries.
$ trl env
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.0
- CUDA device(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.47.0.dev0
- Accelerate version: 0.19.0
- Accelerate config: not found
- Datasets version: 3.0.2
- HF Hub version: 0.26.1
- TRL version: 0.12.0+14ef1ab
- bitsandbytes version: 0.44.1
- DeepSpeed version: 0.15.3
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.13.2
by @qgallouedec in #2104
Sequence-Level KD
From GKD paper:
Sequence-Level KD (Kim & Rush, 2016). SeqKD maximizes the likelihood of high probability sequences generated by the teacher, and can be viewed as supervised FT on teacher-generated outputs.
SeqKD is taken as a baseline in the paper. It is now possible to use Sequence-Level KD in the `GKDTrainer` by setting `seq_kd=True` in the `GKDConfig`.
training_args = GKDConfig(..., seq_kd=True)
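A minimal sketch of a SeqKD run, assuming a conversational dataset with a "messages" column; model names and the dataset are illustrative:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Student, teacher, and tokenizer (placeholders).
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# Hypothetical conversational dataset with a "messages" column.
train_dataset = load_dataset("my-org/my-chat-dataset", split="train")

training_args = GKDConfig(output_dir="gkd-model", seq_kd=True)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()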
Default `dataset_text_field` to `"text"`
Since many users use `"text"` as the column name for textual data in datasets, we've made it the default (previously a required argument) in `SFTConfig`. Now, specifying `dataset_text_field="text"` is no longer necessary.
SFTConfig(
...,
- dataset_text_field="text",
)
by @qgallouedec in #2078
What's Changed
- [SFT] fix neftune_noise_alpha in SFTTrainer by @kashif in #1841
- Standardize `training_args` by @qgallouedec in #2082
- Fix typo in ORPO example. by @skandermoalla in #2092
- Fix Inconsistency with IsShardedQLoRA Setting by @fabianlim in #2089
- Fixes #2087 - _process_tokens for empty prompts in KTOTrainer by @gabikadlecova in #2093
- KTO: fix logits metric, add logits metric to BCOTrainer by ...
v0.11.4
What's Changed
- Fix Inconsistency with IsShardedQLoRA Setting by @fabianlim in #2089
New Contributors
- @fabianlim made their first contribution in #2089
Full Changelog: v0.11.3...v0.11.4
v0.11.3
What's Changed
- [GKD] interpolate in prob. space by @kashif in #2204
- Drop `decoder_input_ids` in `DPOTrainer` by @qgallouedec in #2208
- Update incorrect data processing in DataCollatorForChatML by @ruijunfeng in #2172
New Contributors
- @ruijunfeng made their first contribution in #2172
Full Changelog: v0.11.2...v0.11.3
v0.11.2
v0.11.1
Bug fix
Full Changelog: v0.11.0...v0.11.1