From 45e19d25c4e573b8f58b786e0b49172f06256c41 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 4 Nov 2024 17:45:07 +0000 Subject: [PATCH 01/14] Implement DiscoPOP Loss --- docs/source/dpo_trainer.mdx | 4 ++++ tests/test_dpo_trainer.py | 1 + tests/test_trainers_args.py | 2 ++ trl/commands/scripts | 1 + trl/trainer/dpo_config.py | 6 ++++++ trl/trainer/dpo_trainer.py | 17 ++++++++++++++++- 6 files changed, 30 insertions(+), 1 deletion(-) create mode 120000 trl/commands/scripts diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 0b5020dbad..b22f0f639a 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -167,6 +167,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. +### DiscoPOP loss + +The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). To use this discovered loss, set the `loss_type` value to `discopop` in the [`DPOConfig`]. Additionally, you can change the `discopop_tau` value to change the shape of the DiscoPOP loss. However, the authors recommed the default value `discopop_tau=0.05`. + ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 3194467c2f..17e3f2a2e0 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -196,6 +196,7 @@ def setUp(self): ["t5", "exo_pair", True], ["gpt2", "apo_zero", True], ["t5", "apo_down", False], + ["gpt2", "discopop", False], ] ) def test_dpo_trainer(self, name, loss_type, pre_compute): diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 55802e9a92..a81a05c2d4 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -163,6 +163,7 @@ def test_dpo(self): ref_model_mixup_alpha=0.5, ref_model_sync_steps=32, rpo_alpha=0.5, + discopop_tau=0.1 ) trainer = DPOTrainer( model="gpt2", ref_model="gpt2", args=training_args, train_dataset=dataset, processing_class=tokenizer @@ -193,6 +194,7 @@ def test_dpo(self): self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5) self.assertEqual(trainer.args.ref_model_sync_steps, 32) self.assertEqual(trainer.args.rpo_alpha, 0.5) + self.assertEqual(trainer.args.discopop_tau, 0.1) def test_kto(self): tokenizer = AutoTokenizer.from_pretrained("gpt2") diff --git a/trl/commands/scripts b/trl/commands/scripts new file mode 120000 index 0000000000..801dc126c1 --- /dev/null +++ b/trl/commands/scripts @@ -0,0 +1 @@ +/home/azureuser/caf83/trl/examples/scripts/ \ No newline at end of file diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b84dfb47dd..0693e0f57b 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -65,6 +65,7 @@ class DPOConfig(TrainingArguments): - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. use_weighting (`bool`, *optional*, defaults to `False`): Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. label_pad_token_id (`int`, *optional*, defaults to `-100`): @@ -132,6 +133,9 @@ class DPOConfig(TrainingArguments): α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. + discopop_tau (`float`, *optional*, defaults to `0.05`): + tau/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. """ learning_rate: float = 1e-6 @@ -150,6 +154,7 @@ class DPOConfig(TrainingArguments): "aot_pair", "apo_zero", "apo_down", + "discopop", ] = "sigmoid" use_weighting: bool = False label_pad_token_id: int = -100 @@ -176,6 +181,7 @@ class DPOConfig(TrainingArguments): ref_model_mixup_alpha: float = 0.9 ref_model_sync_steps: int = 64 rpo_alpha: Optional[float] = None + discopop_tau: Optional[float] = 0.05 def __post_init__(self): if self.max_target_length is not None: diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a8192e5905..f9e229b41f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1019,11 +1019,26 @@ def dpo_loss( losses_chosen = F.sigmoid(self.beta * chosen_logratios) losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) losses = losses_chosen + losses_rejected + + elif self.loss_type == "discopop": + # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) + # This loss was discovered with LLM discovery + pi_logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logits = pi_logratios - ref_logratios + logits = logits * self.beta + # Modulate the mixing coefficient based on the log ratio magnitudes + log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) + logistic_component = -F.logsigmoid(logits) + exp_component = torch.exp(-logits) + # Blend between logistic and exponential component based on log ratio modulation + losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation + return losses else: raise ValueError( f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " - "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']" + "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down', 'discopop']" ) chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() From c93ee1c264d536351d8d2c054617be563812cae9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 4 Nov 2024 18:23:57 +0000 Subject: [PATCH 02/14] Updated DiscoPOP documentation --- docs/source/dpo_trainer.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index b22f0f639a..cc40f8b842 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -150,6 +150,7 @@ The DPO algorithm supports several loss functions. The loss function can be set | `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. | | `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. | | `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. | +| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). To use this discovered loss, set the `loss_type` value to `discopop` in the [`DPOConfig`]. | ### Label smoothing From f3e9f8169ea4f205259a275fdc91cccb10f86dce Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:07:11 +0000 Subject: [PATCH 03/14] Corrected docs/source/dpo_trainer.mdx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/dpo_trainer.mdx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index cc40f8b842..dcc4faa969 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -168,10 +168,6 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. -### DiscoPOP loss - -The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). To use this discovered loss, set the `loss_type` value to `discopop` in the [`DPOConfig`]. Additionally, you can change the `discopop_tau` value to change the shape of the DiscoPOP loss. However, the authors recommed the default value `discopop_tau=0.05`. - ### For Mixture of Experts Models: Enabling the auxiliary loss MOEs are the most efficient if the load is about equally distributed between experts. From 3d56cf318e6ad6eb5fd3ccc9e8024486719e2b35 Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:08:27 +0000 Subject: [PATCH 04/14] Update docs/source/dpo_trainer.mdx MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/dpo_trainer.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index dcc4faa969..068f18b312 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -150,7 +150,7 @@ The DPO algorithm supports several loss functions. The loss function can be set | `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. | | `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. | | `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. | -| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). To use this discovered loss, set the `loss_type` value to `discopop` in the [`DPOConfig`]. | +| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). | ### Label smoothing From 228c2c9c45e1ca736037dbdcc9de8c09a87fe660 Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:13:46 +0000 Subject: [PATCH 05/14] Update trl/trainer/dpo_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 0693e0f57b..42e0fb1f94 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -63,9 +63,9 @@ class DPOConfig(TrainingArguments): - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) paper. - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. - - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. use_weighting (`bool`, *optional*, defaults to `False`): Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper. label_pad_token_id (`int`, *optional*, defaults to `-100`): From 05aea6221668e4297c62cedaf1b40ecb4ccde52e Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:14:23 +0000 Subject: [PATCH 06/14] Update trl/trainer/dpo_trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index f9e229b41f..6932b0a1b9 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1038,7 +1038,7 @@ def dpo_loss( else: raise ValueError( f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " - "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down', 'discopop']" + "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', 'apo_down']" ) chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() From e905827930e5d4b35f60fa3f1cc3a4ea2e0bf998 Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:14:59 +0000 Subject: [PATCH 07/14] Update trl/trainer/dpo_trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 6932b0a1b9..baec8f56e1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1033,7 +1033,6 @@ def dpo_loss( exp_component = torch.exp(-logits) # Blend between logistic and exponential component based on log ratio modulation losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation - return losses else: raise ValueError( From 5ff1e1adb19b837398626c74a7d6dbba89cf31a8 Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:15:09 +0000 Subject: [PATCH 08/14] Update trl/trainer/dpo_trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index baec8f56e1..a12adbdb7d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1023,9 +1023,9 @@ def dpo_loss( elif self.loss_type == "discopop": # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) # This loss was discovered with LLM discovery - pi_logratios = chosen_logps - rejected_logps + logratios = chosen_logps - rejected_logps ref_logratios = ref_chosen_logps - ref_rejected_logps - logits = pi_logratios - ref_logratios + logits = logratios - ref_logratios logits = logits * self.beta # Modulate the mixing coefficient based on the log ratio magnitudes log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) From 53d3ca0260463773aa88d9f6085fa767eaa675ba Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:15:19 +0000 Subject: [PATCH 09/14] Update trl/trainer/dpo_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 42e0fb1f94..875b211a33 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -181,7 +181,7 @@ class DPOConfig(TrainingArguments): ref_model_mixup_alpha: float = 0.9 ref_model_sync_steps: int = 64 rpo_alpha: Optional[float] = None - discopop_tau: Optional[float] = 0.05 + discopop_tau: float = 0.05 def __post_init__(self): if self.max_target_length is not None: From 4139972a8b6d5a0e137dbbc4eef4f2a1f36e7e9e Mon Sep 17 00:00:00 2001 From: Claudio Date: Mon, 4 Nov 2024 19:15:29 +0000 Subject: [PATCH 10/14] Update trl/trainer/dpo_config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index 875b211a33..b968eaa80d 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -134,7 +134,7 @@ class DPOConfig(TrainingArguments): weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends `rpo_alpha=1.0`. discopop_tau (`float`, *optional*, defaults to `0.05`): - tau/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. """ From 01fd3ecff686d29d7fb7de43dde6738aa9278f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sat, 9 Nov 2024 14:06:21 -0400 Subject: [PATCH 11/14] Update trl/trainer/dpo_config.py --- trl/trainer/dpo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py index b968eaa80d..d417323b58 100644 --- a/trl/trainer/dpo_config.py +++ b/trl/trainer/dpo_config.py @@ -152,9 +152,9 @@ class DPOConfig(TrainingArguments): "sppo_hard", "aot", "aot_pair", + "discopop", "apo_zero", "apo_down", - "discopop", ] = "sigmoid" use_weighting: bool = False label_pad_token_id: int = -100 From e4df418efd63550d4a936ea3ff1d11467f8c4137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 18 Nov 2024 10:57:12 +0000 Subject: [PATCH 12/14] Delete scripts directory --- trl/commands/scripts | 1 - 1 file changed, 1 deletion(-) delete mode 120000 trl/commands/scripts diff --git a/trl/commands/scripts b/trl/commands/scripts deleted file mode 120000 index 801dc126c1..0000000000 --- a/trl/commands/scripts +++ /dev/null @@ -1 +0,0 @@ -/home/azureuser/caf83/trl/examples/scripts/ \ No newline at end of file From 97d8478018364768bdf618e45cbfdfed19353c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 18 Nov 2024 12:11:46 +0000 Subject: [PATCH 13/14] style --- tests/test_trainers_args.py | 2 +- trl/trainer/dpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index a81a05c2d4..2f62e658a7 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -163,7 +163,7 @@ def test_dpo(self): ref_model_mixup_alpha=0.5, ref_model_sync_steps=32, rpo_alpha=0.5, - discopop_tau=0.1 + discopop_tau=0.1, ) trainer = DPOTrainer( model="gpt2", ref_model="gpt2", args=training_args, train_dataset=dataset, processing_class=tokenizer diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 2bab4ef043..c699bce2ee 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1021,7 +1021,7 @@ def dpo_loss( losses_chosen = F.sigmoid(self.beta * chosen_logratios) losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) losses = losses_chosen + losses_rejected - + elif self.loss_type == "discopop": # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) # This loss was discovered with LLM discovery From b8dd80e1d8985ebb0bf49b4acd22d9ecfeca1c1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 18 Nov 2024 12:56:40 +0000 Subject: [PATCH 14/14] empty commit