Major and breaking changes
🐾 Process-supervised RM Trainer
We introduced a new trainer for training Process-supervised Reward Models (PRMs) in TRL. A PRM rewards the quality of intermediate reasoning steps, promoting structured reasoning rather than focusing solely on the final outcome. With this trainer, we introduce a new dataset type: stepwise supervision. It is a variant of the prompt-completion type in which the completion is divided into several intermediate steps, each associated with a label. Find out more in the stepwise-supervision section of the TRL documentation.
Here is an example of how to use the PRMTrainer to train a PRM on the Math Shepherd dataset:
```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Shepherd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
For more information, check out the PRMTrainer documentation.
by @qgallouedec and @gaetanlop in #2127 and #2148
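To make the stepwise-supervision type concrete, here is a sketch of one record and of one way such a record can be turned into token-level labels for a token-classification PRM. The record layout (prompt, completions, labels) follows the stepwise-supervision format; the toy token IDs, the separator ID and the helper build_token_labels are hypothetical, and the labeling scheme (only the last token of each step carries that step's label, all other tokens are ignored with -100) is an assumption about a common setup, not a description of PRMTrainer internals.

```python
# One stepwise-supervision record: each reasoning step in "completions" has a label.
stepwise_example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
        "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.",
    ],
    "labels": [True, False],  # the first step is correct, the second is not
}

IGNORE_INDEX = -100  # label value skipped by the usual cross-entropy loss


def build_token_labels(prompt_ids, step_ids, step_labels, separator_ids):
    """Hypothetical helper: concatenate prompt and steps, labeling only the last token of each step."""
    if len(step_ids) != len(step_labels):
        raise ValueError("Each step needs exactly one label.")
    input_ids = list(prompt_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids)  # the prompt is never scored
    for ids, label in zip(step_ids, step_labels):
        step = list(ids) + list(separator_ids)  # close each step with the separator tokens
        input_ids.extend(step)
        # Only the final token of the step (the separator) carries the step's label.
        labels.extend([IGNORE_INDEX] * (len(step) - 1) + [int(label)])
    return input_ids, labels


# Toy token IDs standing in for a real tokenizer's output.
input_ids, labels = build_token_labels(
    prompt_ids=[1, 2],
    step_ids=[[11, 12], [13]],
    step_labels=stepwise_example["labels"],
    separator_ids=[99],
)
print(input_ids)  # [1, 2, 11, 12, 99, 13, 99]
print(labels)     # [-100, -100, -100, -100, 1, -100, 0]
```

Because each step is scored at a single position, the model learns a per-step correctness prediction rather than a single score for the whole completion.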
🔀 Add MergeModelCallback
Several works show that model merging can improve performance non-trivially, especially when the merged models share the same architecture. TRL now features a callback that merges the reference model with the current policy and optionally pushes the merged checkpoint to the Hub. Merging can be triggered at the end of each step or epoch and/or at the end of training. The callback relies on Arcee's mergekit library: https://github.com/arcee-ai/mergekit
```python
from trl import DPOTrainer, MergeModelCallback
from trl.mergekit_utils import MergeConfig

config = MergeConfig()
merge_callback = MergeModelCallback(config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
```
by @August-murr in #2282
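Merge methods vary; the simplest is a linear merge, a weighted average of corresponding parameters. The sketch below shows that idea on plain Python lists standing in for tensors. It illustrates linear merging in general, not mergekit's implementation, and it does not claim that MergeConfig defaults to this method or to equal weights; the function linear_merge and its policy_weight parameter are hypothetical.

```python
def linear_merge(policy_params, reference_params, policy_weight=0.5):
    """Weighted average of two models' parameters (hypothetical helper; lists stand in for tensors)."""
    if policy_params.keys() != reference_params.keys():
        raise ValueError("Models must have the same parameter names (same architecture).")
    reference_weight = 1.0 - policy_weight
    merged = {}
    for name, policy_values in policy_params.items():
        reference_values = reference_params[name]
        if len(policy_values) != len(reference_values):
            raise ValueError(f"Shape mismatch for parameter {name!r}.")
        merged[name] = [
            policy_weight * p + reference_weight * r for p, r in zip(policy_values, reference_values)
        ]
    return merged


policy = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
reference = {"layer.weight": [3.0, 4.0], "layer.bias": [1.0]}
print(linear_merge(policy, reference))  # {'layer.weight': [2.0, 3.0], 'layer.bias': [0.5]}
```

The same-architecture condition mentioned above shows up here as the check that both models have identical parameter names and shapes; element-wise averaging is only meaningful under that condition.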
🔨 Support for tools for data utils
TRL preprocessing utilities now support tools. This is a first step toward agent fine-tuning.
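```python
from transformers import AutoTokenizer
from trl import apply_chat_template

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

# Assumption: any tokenizer whose chat template supports tools, e.g. a Qwen2.5 instruct model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Assumption: a conversational example stored under a "messages" key
example = {"messages": [{"role": "user", "content": "What is the temperature in Paris?"}]}

example = apply_chat_template(example, tokenizer, tools=[get_current_temperature])
```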
by @August-murr in #2455
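Chat templates receive tools as JSON-schema descriptions, which is why the function above needs type hints and a docstring with an Args: section. The standard-library sketch below builds a simplified schema of that kind. transformers has its own, more complete converter, so treat function_to_tool_schema and its basic type mapping as hypothetical illustrations rather than the exact output TRL or transformers produce.

```python
import inspect
import typing

# Simplified mapping from Python annotations to JSON-schema types (assumption: basic types only).
JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def function_to_tool_schema(func):
    """Hypothetical helper: derive a tool schema from type hints and a Google-style docstring."""
    doc = inspect.getdoc(func) or ""
    description, _, args_block = doc.partition("Args:")
    arg_docs = {}
    for line in args_block.splitlines():
        name, sep, text = line.strip().partition(":")
        if sep:
            arg_docs[name.strip()] = text.strip()
    hints = typing.get_type_hints(func)
    properties, required = {}, []
    for name, param in inspect.signature(func).parameters.items():
        properties[name] = {
            "type": JSON_TYPES.get(hints.get(name), "string"),
            "description": arg_docs.get(name, ""),
        }
        if param.default is inspect.Parameter.empty:
            required.append(name)  # parameters without defaults are required
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": description.strip(),
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0


schema = function_to_tool_schema(get_current_temperature)
print(schema["function"]["parameters"]["properties"])
# {'location': {'type': 'string', 'description': 'The location to get the temperature for'}}
```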
🌋 Add support for LLaVA-Next in DPOTrainer
Vision-language models (VLMs) have specific requirements that need special handling in the trainer. DPOTrainer now supports LLaVA-Next models natively.
```python
from transformers import AutoModelForVision2Seq
from trl import DPOTrainer

model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
trainer = DPOTrainer(model=model, ...)
```
by @chenweize1998 in #2413
🕹️ CLI and TRLParser refactor
The TRL CLI has been refactored to be more user-friendly and easier to extend. We plan to extend CLI support to all trainers soon.
(Simplified output, for readability.)
```
$ trl dpo --help
usage: trl dpo [-h] --dataset_name DATASET_NAME [--dataset_config DATASET_CONFIG] --output_dir OUTPUT_DIR [--loss_type {sigmoid,hinge,ipo}]

options:
  -h, --help            show this help message and exit
  --dataset_name DATASET_NAME, --dataset-name DATASET_NAME
  --dataset_config DATASET_CONFIG, --dataset-config DATASET_CONFIG
  --output_dir OUTPUT_DIR, --output-dir OUTPUT_DIR
                        The output directory where the model predictions and checkpoints will be written. (default: None)
  --loss_type {sigmoid,hinge,ipo}, --loss-type {sigmoid,hinge,ipo}
```
by @qgallouedec in #2380 and #2412
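The help output lists every option under both an underscore and a hyphen spelling. The standard-library argparse sketch below reproduces that aliasing for a few of the options. It is not TrlParser's implementation; the prog name, the build_parser helper and the sigmoid default for loss_type are assumptions made for the illustration.

```python
import argparse


def build_parser():
    """Hypothetical parser accepting both --snake_case and --kebab-case spellings."""
    parser = argparse.ArgumentParser(prog="trl-dpo-sketch")
    # Passing two option strings registers aliases that store into the same dest.
    parser.add_argument("--dataset_name", "--dataset-name", dest="dataset_name", required=True)
    parser.add_argument("--dataset_config", "--dataset-config", dest="dataset_config", default=None)
    parser.add_argument(
        "--output_dir",
        "--output-dir",
        dest="output_dir",
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--loss_type",
        "--loss-type",
        dest="loss_type",
        choices=["sigmoid", "hinge", "ipo"],
        default="sigmoid",  # assumption for this sketch
    )
    return parser


args = build_parser().parse_args(
    ["--dataset-name", "my-dataset", "--output_dir", "out", "--loss-type", "ipo"]
)
print(args.dataset_name, args.output_dir, args.loss_type)  # my-dataset out ipo
```

Mixing spellings on one command line works because both aliases write to the same destination attribute.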
🤝 Mixture of judges
TRL features a new judge, AllTrueJudge, that unifies the decisions of multiple binary judges. This judge implements the Mixture of Judges described in the CGPO paper.
```python
import random

from trl import AllTrueJudge, BaseBinaryJudge

class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        # Returns one of 0, 1 or -1 per prompt, at random
        return [random.choice([0, 1, -1]) for _ in range(len(prompts))]

prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = ["Paris", "Jupiter"]  # binary judges score one completion per prompt
judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
judgements = judge.judge(prompts=prompts, completions=completions)
print(judgements)  # e.g. [0, 1]; values vary because the inner judges are random
```
by @gaetanlop in #2159
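The "all true" rule can be written out without any model. The sketch below combines per-judge decisions for each example, assuming the convention 1 = the completion passes, 0 = it fails, and -1 = the judge itself failed. We believe this mirrors AllTrueJudge's rule, but all_true is a hypothetical standalone helper, not TRL's code.

```python
def all_true(judgements_per_judge):
    """Combine binary judgements: judgements_per_judge[j][i] is judge j's decision on example i.

    Assumed convention: 1 = passes, 0 = fails, -1 = the judge itself failed.
    """
    combined = []
    for decisions in zip(*judgements_per_judge):  # all judges' decisions for one example
        if any(d not in (0, 1, -1) for d in decisions):
            raise ValueError(f"Unexpected judgement in {decisions}")
        if -1 in decisions:
            combined.append(-1)  # propagate a judge failure
        elif all(d == 1 for d in decisions):
            combined.append(1)   # every judge accepts
        else:
            combined.append(0)   # at least one judge rejects
    return combined


print(all_true([[1, 1, 0, -1], [1, 0, 0, 1]]))  # [1, 0, 0, -1]
```

Requiring unanimity makes the combined judge strict: a completion is accepted only if it satisfies every constraint checked by the individual judges.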
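❄️ DPO trainer supports num_logits_to_keep to save memory
With use_num_logits_to_keep enabled, the DPO trainer saves memory by computing logits only for the last num_logits_to_keep positions of each sequence, starting where the completion begins, instead of for every token including the prompt.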
```python
from trl import DPOConfig

training_args = DPOConfig(..., use_num_logits_to_keep=True)
```
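The saving comes from the size of the logits tensor, which scales with batch size × number of positions × vocabulary size. The sketch below first checks, on tiny integer values, that projecting only the last k hidden states gives exactly the last k rows of the full logits, then estimates memory. project stands in for the model's LM head and is not a TRL function; the sizes (batch 8, 1,024 positions, a 150,000-token vocabulary, 256 kept positions, fp32 logits) are hypothetical.

```python
def project(hidden_states, lm_head_weight):
    """Bias-free linear LM head: maps each hidden vector (size H) to vocabulary logits (size V)."""
    return [[sum(h * w for h, w in zip(vector, row)) for row in lm_head_weight] for vector in hidden_states]


def logits_memory_bytes(batch_size, num_positions, vocab_size, bytes_per_value=4):
    """Memory needed to store logits of shape [batch, positions, vocab] (4 bytes = fp32)."""
    return batch_size * num_positions * vocab_size * bytes_per_value


# Toy check: 5 positions with hidden size 2, a vocabulary of 3 tokens, keep the last 2 positions.
hidden = [[1, 0], [0, 1], [1, 1], [2, 1], [1, 3]]
weight = [[1, 2], [0, 1], [3, 0]]  # one row per vocabulary entry
k = 2
full_logits = project(hidden, weight)
kept_logits = project(hidden[-k:], weight)
assert kept_logits == full_logits[-k:]  # identical values for the positions that are kept

full = logits_memory_bytes(batch_size=8, num_positions=1024, vocab_size=150_000)
kept = logits_memory_bytes(batch_size=8, num_positions=256, vocab_size=150_000)
print(full, kept)  # 4915200000 1228800000
```

With these hypothetical sizes, skipping the prompt positions cuts logits memory from about 4.9 GB to about 1.2 GB, without changing the values the loss actually uses.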
🗺️ Implementation of the DiscoPOP loss
The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper, the proposed DiscoPOP loss (a log-ratio modulated loss) outperformed other optimization losses on several tasks: IMDb positive text generation, Reddit TL;DR summarization, and AlpacaEval 2.0.
```python
from trl import DPOConfig

training_args = DPOConfig(..., loss_type="discopop", discopop_tau=0.05)
```
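As described in the DiscoPOP paper, the log-ratio modulated loss blends the logistic (DPO) loss and an exponential loss, using a sigmoid of the scaled log-ratio difference divided by τ (discopop_tau) as the mixing weight. The pure-Python sketch below writes that blend out for a single preference pair. The function name, the beta value and the per-pair inputs (policy-minus-reference log-probability ratios of the chosen and rejected completions) are assumptions for illustration, not TRL's batched implementation.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softplus(x):
    # log(1 + exp(x)), computed stably; note that -log(sigmoid(x)) == softplus(-x).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def discopop_loss(chosen_logratio, rejected_logratio, beta=0.1, tau=0.05):
    """Log-ratio modulated loss for one pair (hypothetical helper).

    chosen_logratio / rejected_logratio: log pi(y|x) - log pi_ref(y|x) for each completion.
    """
    logits = beta * (chosen_logratio - rejected_logratio)
    mix = sigmoid(logits / tau)              # modulation weight in (0, 1)
    logistic_component = softplus(-logits)   # DPO loss: -log(sigmoid(logits))
    exp_component = math.exp(-logits)        # exponential loss
    return (1.0 - mix) * logistic_component + mix * exp_component


print(round(discopop_loss(0.0, 0.0), 4))  # 0.8466
```

Because τ is small, the mixing weight switches quickly: pairs the policy already ranks correctly (positive logits) are scored mostly by the exponential term, while misranked pairs are scored mostly by the logistic term.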
🧑🍳 Add precompute batch size argument in DPOTrainer for reference model
We can now control the batch size used to precompute the reference model's log probabilities.
```python
from trl import DPOConfig

training_args = DPOConfig(
    ...,
    precompute_ref_log_probs=True,
    precompute_ref_batch_size=4,
)
```
by @SwayamInSync in #2426
📦 Support for packing tokenized datasets for SFT
SFTTrainer already supported packing datasets for faster training. It now supports packing tokenized datasets as well.
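Packing concatenates tokenized examples, separated by an end-of-sequence token, and cuts the resulting stream into fixed-length chunks so that no compute is spent on padding. The sketch below shows that idea on lists of token IDs. pack_sequences, the EOS ID of 0 and the choice to drop the incomplete final chunk are assumptions for illustration, not SFTTrainer's exact behavior.

```python
def pack_sequences(sequences, seq_length, eos_token_id):
    """Concatenate tokenized examples (each followed by EOS) and split into equal-length chunks."""
    stream = []
    for token_ids in sequences:
        stream.extend(token_ids)
        stream.append(eos_token_id)  # mark the boundary between examples
    num_chunks = len(stream) // seq_length  # sketch choice: drop the incomplete tail
    return [stream[i * seq_length:(i + 1) * seq_length] for i in range(num_chunks)]


tokenized = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_sequences(tokenized, seq_length=4, eos_token_id=0))
# [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
```

Note that a chunk can span two examples (the second chunk above mixes the end of one and the start of the next); the EOS token is what lets the model see the boundary.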
📉 Add PEFT support for PPOTrainer
PPOTrainer now supports PEFT for parameter-efficient training.
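```python
from peft import LoraConfig
from trl import PPOTrainer

peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")  # example LoRA settings
trainer = PPOTrainer(
    ...,
    peft_config=peft_config,
)
```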
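LoRA is one common PEFT method (assumed here; a peft_config can describe other methods). It freezes a weight matrix W of shape d × k and trains two small matrices B (d × r) and A (r × k), so the effective weight is W + (alpha / r) · B A. The sketch below applies that update on tiny lists and counts trainable parameters; the helper names and the 4096 × 4096, r = 16 sizes are hypothetical.

```python
def matmul(a, b):
    """Plain-Python product of an (n x m) and an (m x p) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def lora_effective_weight(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * B @ A; W stays frozen, only A and B would be trained."""
    delta = matmul(lora_b, lora_a)
    scale = alpha / r
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]


def trainable_parameters(d, k, r=None):
    """Trainable parameter count for a d x k layer: full fine-tuning vs. LoRA of rank r."""
    return d * k if r is None else r * (d + k)


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2 x 2 weight
B = [[1.0], [0.0]]            # 2 x 1 (rank r = 1)
A = [[0.0, 2.0]]              # 1 x 2
print(lora_effective_weight(W, A, B, alpha=2, r=1))  # [[1.0, 4.0], [0.0, 1.0]]
print(trainable_parameters(4096, 4096), trainable_parameters(4096, 4096, r=16))  # 16777216 131072
```

In this hypothetical layer, LoRA trains 128 times fewer parameters than full fine-tuning. In practice B is typically initialized to zeros, so training starts exactly from the frozen weights.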
💾 Deprecate config in favor of args in PPOTrainer
The config argument of PPOTrainer has been deprecated in favor of args.
```diff
 PPOTrainer(
-    config=training_args,
+    args=training_args,
 )
```
by @qgallouedec in #2384
👮 Deprecate policy in favor of model in PPOTrainer
The policy argument of PPOTrainer has been deprecated in favor of model.
```diff
 PPOTrainer(
-    policy=model,
+    model=model,
 )
```
by @qgallouedec in #2386
What's Changed
- ⏫ Bump dev version to 0.13.0.dev0 by @qgallouedec in #2305
- 📰 Update blog posts in documentation by @qgallouedec in #2319
- ⚰️ Remove deprecated args, script arguments, and PPOv2 by @qgallouedec in #2306
- 🧽 Fix judge doc by @qgallouedec in #2320
- 🪧 Fix slack notification titles by @qgallouedec in #2322
- 🪪 Check with token_id instead of token in DPOTrainer by @qgallouedec in #2324
- Fix wrong truncating index of tensor in DPOTrainer's concatenated_forward() by @yanghh2000 in #2332
- Fix gradient_checkpointing_kwargs assignment in examples by @Galaxy-Husky in #2331
- Bump liger-kernel to 0.4.0 by @ByronHsu in #2333
- DPO trainer supports num_logits_to_keep to save memory by @xyangk in #2129
- 🧞 Add output_layer to the list of lm_head_namings in AutoModelForCausalLMWithValueHead by @qgallouedec in #2328
- 🫴 Better guide users in error reporting by @qgallouedec in #2327
- 🪡 Various RLOO fixes by @qgallouedec in #2325
- 💣 Remove transformers version check by @xyangk in #2343
- 👈 Add tokenizer arg back and add deprecation guidelines by @qgallouedec in #2348
- 🖨️ Fix error text in BCO and KTO tokenizing function by @PhilipMay in #2286
- Adding video llm fine-tuning example by @mfarre in #2336
- 👋 Remove deprecated tokenizer argument in BCO, GKD, Iterative SFT, Nash MD and XPO by @qgallouedec in #2349
- ⚖️ Add use_soft_judge option to WinRateCallback by @kashif in #2347
- 🪜 Stepwise supervision dataset type by @qgallouedec in #2148
- 🔮 Inference mode in GeometricMixtureWrapper.forward by @kashif in #2345
- 🗃️ Use specified data_collator in RLOOTrainer and PPOTrainer by @bartoszzuk in #2360
- 📉 Add PEFT support for PPOTrainer by @ccs96307 in #2344
- 📃 Fix description for parameter "generate_during_eval" in dpo_config by @dakru012 in #2364
- 🗺️ Implementation DiscoPOP Loss by @fanconic in #2323
- 🤝 Mixture of judges by @gaetanlop in #2159
- 🎲 Move random judges in testing utilities by @qgallouedec in #2365
- Fix dev install by @lewtun in #2369
- [winrate callback] remove redundant call to eval and train by @kashif in #2372
- 🧲 Use our own require_bitsandbytes by @qgallouedec in #2370
- ⏰ Add start_time to _maybe_log_save_evaluate by @qgallouedec in #2373
- 🔀 Add MergeModelCallBack by @August-murr in #2282
- 📝 Fix typo in dataset generation script by @jiseshen in #2379
- ⌛ Update log method to include start_time parameter by @qgallouedec in #2381
- 🙈 Suppress warning for estimating tokens in trainers by @qgallouedec in #2389
- 📦 Support for packing tokenized datasets for SFT by @kmehant in #2011
- 💾 Deprecate config in favor of args in PPOTrainer by @qgallouedec in #2384
- 🤏 New models for tests by @qgallouedec in #2287
- 👮 Deprecate policy in favor of model in PPOTrainer by @qgallouedec in #2386
- 🤐 Fix deprecation warnings by @qgallouedec in #2392
- 🤐 Fix deprecation warnings by @qgallouedec in #2395
- 🖋️ Fix warning message formatting in KTOTrainer by @qgallouedec in #2394
- 🧳 Move zen generation script and fix tests by @qgallouedec in #2393
- 🐢 Fix slow tests by @kashif in #2397
- 🗝️ Update type hints by @qgallouedec in #2399
- 🖨 Add Script Utilities section to the documentation by @qgallouedec in #2407
- 👁️ Added SFT support for SmolVLM models via standalone script sft_vlm_smol_vlm.py by @sergiopaniego in #2409
- Add note about special tokens in chat templates for LoRA SFT by @lewtun in #2414
- 🌐 Community Tutorials by @burtenshaw in #2411
- 🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead by @qgallouedec in #2398
- 🌋 Add support for LLaVA-Next in DPOTrainer by @chenweize1998 in #2413
- ⚠️ Add warning guidelines and update codebase to follow best practices by @qgallouedec in #2350
- Super tiny typo fix by @fzyzcjy in #2419
- 🧑🍳 Add precompute batch size argument in DPOTrainer for reference model by @SwayamInSync in #2426
- 📑 Refactor TrlParser by @qgallouedec in #2412
- 🔮 Fix unused precomputed ref log probs in DPO by @dakru012 in #2431
- 🧮 Fix max_steps calculation in RLOOTrainer by @qgallouedec in #2433
- 🗂️ Harmonize run and example batch sizes in RLOO docs by @asparius in #2439
- 🔗 Add "Open in Colab" badges in community tutorials page by @qgallouedec in #2441
- ©️ Copyrights update by @qgallouedec in #2454
- 💬 Fix chat for windows by @qgallouedec in #2443
- 🆔 Add dataset_config to ScriptArguments by @qgallouedec in #2440
- 🏎 Fix deepspeed preparation of ref_model in OnlineDPOTrainer by @qgallouedec in #2417
- 👯 Standardize model_args by @qgallouedec in #2442
- [bugfix] Fix DataCollatorForChatML unexpected generation prompt by @NIL-zhuang in #2450
- ⚖️ Add tests_latest.yml workflow file by @qgallouedec in #2457
- 🛠️ Update tests and fix PPO by @kashif in #2463
- 🎞️ Add "Fine-tuning open AI models using Hugging Face TRL" YouTube video to community tutorials by @qgallouedec in #2467
- 🔨 Support for tools for data utils by @August-murr in #2455
- 🐾 Process-supervised RM Trainer by @gaetanlop in #2127
- 🕹️ CLI refactor by @qgallouedec in #2380
- 👀 Add "PaliGemma 🤝 Direct Preference Optimization" in community tutorials by @qgallouedec in #2475
- ☄️ Add support for Comet experiment management SDK integration by @yaricom in #2462
- 📥 Fix missing BitsAndBytesConfig import in doc by @August-murr in #2478
- 👨🏫 smol course links and badges by @qgallouedec in #2484
New Contributors
- @yanghh2000 made their first contribution in #2332
- @Galaxy-Husky made their first contribution in #2331
- @ByronHsu made their first contribution in #2333
- @xyangk made their first contribution in #2129
- @mfarre made their first contribution in #2336
- @dakru012 made their first contribution in #2364
- @fanconic made their first contribution in #2323
- @jiseshen made their first contribution in #2379
- @kmehant made their first contribution in #2011
- @burtenshaw made their first contribution in #2411
- @chenweize1998 made their first contribution in #2413
- @fzyzcjy made their first contribution in #2419
- @SwayamInSync made their first contribution in #2426
- @asparius made their first contribution in #2439
- @NIL-zhuang made their first contribution in #2450
- @yaricom made their first contribution in #2462
Full Changelog: v0.12.0...v0.13.0