From 7db18dbf7501efa5da0ad7af8bdcab8e2a6dafea Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Sat, 14 Dec 2024 18:29:45 +0000 Subject: [PATCH 1/6] adding tool fine-tuning support for DPO --- trl/trainer/dpo_config.py | 6 +++++- trl/trainer/dpo_trainer.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 88abdd4a5c..41dab64891 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass from enum import Enum -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional,Callable,Union from transformers import TrainingArguments @@ -51,6 +51,9 @@ class DPOConfig(TrainingArguments): label_smoothing (`float`, *optional*, defaults to `0.0`): Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. + tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): + A list of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. Possible values are: @@ -151,6 +154,7 @@ class DPOConfig(TrainingArguments): learning_rate: float = 1e-6 beta: float = 0.1 label_smoothing: float = 0.0 + tools: Optional[list[Union[dict, Callable]]] = None loss_type: Literal[ "sigmoid", "hinge", diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3c4c7771b2..b719c41c8d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -445,7 +445,7 @@ def make_inputs_require_grad(module, input, output): ) train_dataset = train_dataset.map( maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, + fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, num_proc=args.dataset_num_proc, desc="Applying chat template to train dataset", ) @@ -455,7 +455,7 @@ def make_inputs_require_grad(module, input, output): ) eval_dataset = eval_dataset.map( maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, + fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, num_proc=args.dataset_num_proc, desc="Applying chat template to eval dataset", ) From 571a138801cf9c7339f73381b44add99a4568980 Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Sat, 14 Dec 2024 18:47:05 +0000 Subject: [PATCH 2/6] precommit --- trl/trainer/dpo_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 41dab64891..77f2bd0c30 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass from enum import Enum -from typing import Any, Literal, Optional,Callable,Union +from typing import Any, Callable, Literal, Optional, Union from transformers import TrainingArguments @@ -53,7 +53,7 @@ class DPOConfig(TrainingArguments): [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): A list of tools (callable functions) that will be accessible to the model. - If the template does not support function calling, this argument will have no effect + If the template does not support function calling, this argument will have no effect loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. Possible values are: From dc198495271d20cae8e5455d5110cdb68309ac66 Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Wed, 25 Dec 2024 11:46:18 +0000 Subject: [PATCH 3/6] adding test for DPOTrainer with tool usage --- tests/test_dpo_trainer.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index fe2c732f64..976289d248 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1165,6 +1165,42 @@ def test_dpo_trainer_use_num_logits_to_keep(self): trainer.train() + def test_dpo_trainer_with_tools(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Define dummy test tools + def get_current_temperature(location: str): + """ + Gets the temperature at a given location. + + Args: + location: The location to get the temperature for + """ + return 22.0 + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + tools=[get_current_temperature], + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference") + + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + self.assertIn("get_current_temperature", trainer.train_dataset["prompt"][0]) + @require_vision class DPOVisionTrainerTester(unittest.TestCase): From f80c05d6d840c03b13497d751123b23a010d6823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 20 Jan 2025 17:59:13 +0000 Subject: [PATCH 4/6] style --- trl/trainer/dpo_config.py | 2 +- trl/trainer/dpo_trainer.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 9ef8bffad4..b7c18e11cc 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Optional, Union from transformers import TrainingArguments diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index e25f52344b..ba946ad8c1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -542,7 +542,9 @@ def _prepare_dataset( # Apply the chat template if needed if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs) + dataset = dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs + ) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` From 4af20768c5945ea489d59ccd129c76c74f362976 Mon Sep 17 00:00:00 2001 From: August Moharrami Date: Mon, 20 Jan 2025 19:20:21 +0000 Subject: [PATCH 5/6] fix test --- tests/test_dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 1c34b46c25..d11532c3a5 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1186,7 +1186,7 @@ def get_current_temperature(location: str): eval_dataset=dummy_dataset["test"], ) - self.assertIn("get_current_temperature", trainer.train_dataset["prompt"][0]) + self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) def test_padding_free(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" From d7caae5376f1b61480b6dc79b971924e933a0941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 20 Jan 2025 20:49:52 +0000 Subject: [PATCH 6/6] a comment --- tests/test_dpo_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index d11532c3a5..c4a0232ee3 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1185,7 +1185,9 @@ def get_current_temperature(location: str): train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], ) - + # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When + # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's + # what we're checking here self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) def test_padding_free(self):