🧰 Tool fine-tuning support DPO #2479

Merged · 8 commits · Jan 21, 2025
39 changes: 38 additions & 1 deletion tests/test_dpo_trainer.py
@@ -1152,11 +1152,48 @@ def test_dpo_trainer_use_num_logits_to_keep(self):

        trainer.train()

    def test_dpo_trainer_with_tools(self):
        model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(model_id)

        # Define dummy test tools
        def get_current_temperature(location: str):
            """
            Gets the temperature at a given location.

            Args:
                location: The location to get the temperature for
            """
            return 22.0

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                tools=[get_current_temperature],
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference")

            trainer = DPOTrainer(
                model=model,
                ref_model=None,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
            )
            # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When
            # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's
            # what we're checking here.
            self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0]))

    def test_padding_free(self):
        model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_token = tokenizer.eos_token
        # Normally, we need `attn_implementation="flash_attention_2"` so that the model returns correct logits.
        # Without it, the logits may be incorrect, but that's fine here. This test focuses only on the inner logic
        # of padding_free.
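For context on the `test_dpo_trainer_with_tools` assertion above: it relies on the tokenizer's chat template accepting a `tools` argument and rendering the tool schemas into the system section of the prompt. A minimal sketch of that mechanism, using the same tiny test checkpoint (the user message is illustrative):

from transformers import AutoTokenizer

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
messages = [{"role": "user", "content": "What is the temperature in Paris?"}]

# The chat template converts the function signature and docstring into a JSON schema and
# inserts it into the rendered prompt, so the tool name should appear in the text.
prompt = tokenizer.apply_chat_template(
    messages,
    tools=[get_current_temperature],
    tokenize=False,
    add_generation_prompt=True,
)
print("get_current_temperature" in prompt)  # expected: True when the template supports tools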
12 changes: 11 additions & 1 deletion trl/trainer/dpo_config.py
@@ -15,7 +15,7 @@
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
from typing import Any, Callable, Optional, Union

from transformers import TrainingArguments

@@ -93,6 +93,9 @@ class DPOConfig(TrainingArguments):
            Batch size to use when precomputing reference model log probabilities. This can be set higher than the
            training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
            training and `per_device_eval_batch_size` for evaluation.
        tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`):
            List of tools (callable functions) that will be accessible to the model. If the template does not
            support function calling, this argument will have no effect.

        > Parameters that control the training

@@ -261,6 +264,13 @@ class DPOConfig(TrainingArguments):
            "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation."
        },
    )
    tools: Optional[list[Union[dict, Callable]]] = field(
        default=None,
        metadata={
            "help": "List of tools (callable functions) that will be accessible to the model. If the template does "
            "not support function calling, this argument will have no effect."
        },
    )

    # Parameters that control the training
    learning_rate: float = field(
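As a usage sketch of the new config field (the output path is a placeholder, and the tool function mirrors the one from the test above):

from trl import DPOConfig

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

# `tools` accepts callables (converted to JSON schemas through the chat template) or
# already-built schema dicts; it is forwarded to the template during pre-processing.
training_args = DPOConfig(
    output_dir="dpo-with-tools",  # placeholder
    tools=[get_current_temperature],
)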
4 changes: 3 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -543,7 +543,9 @@ def _prepare_dataset(
        # Apply the chat template if needed
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
            map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
        dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)
        dataset = dataset.map(
            maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs
        )

        # Tokenize the dataset
        if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
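To illustrate what this change does to a single example during pre-processing, a rough sketch (assuming `maybe_apply_chat_template` accepts a `tools` keyword, as the `fn_kwargs` above imply; the example dict follows the conversational preference format):

from transformers import AutoTokenizer
from trl import maybe_apply_chat_template

def get_current_temperature(location: str):
    """
    Gets the temperature at a given location.

    Args:
        location: The location to get the temperature for
    """
    return 22.0

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

example = {
    "prompt": [{"role": "user", "content": "What is the temperature in Paris?"}],
    "chosen": [{"role": "assistant", "content": "Let me check with the weather tool."}],
    "rejected": [{"role": "assistant", "content": "It is always sunny."}],
}

# With `tools` set, the rendered prompt string should mention the tool, provided the
# tokenizer's chat template supports function calling.
processed = maybe_apply_chat_template(example, tokenizer, tools=[get_current_temperature])
print("get_current_temperature" in processed["prompt"])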