diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 7b83707b82..41486e1f12 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1933,6 +1933,12 @@ def build_training_arguments(self, total_num_steps):
         else:
             training_args_cls = AxolotlDPOConfig
+            if self.cfg.rl == "ipo":
+                training_args_kwargs["loss_type"] = "ipo"
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            training_args_kwargs["max_completion_length"] = None
+            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
 
         if self.cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
 
@@ -1956,7 +1962,6 @@ def build(self, total_num_steps):
         training_args = self.build_training_arguments(total_num_steps)
         dpo_trainer_kwargs = {}
         if self.cfg.rl == "ipo":
-            dpo_trainer_kwargs["loss_type"] = "ipo"
             if self.cfg.dpo_label_smoothing:
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
@@ -1970,12 +1975,6 @@ def build(self, total_num_steps):
         if self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = AxolotlDPOTrainer
             trainer_cls_args = [self.model, self.model_ref]
-
-            # these aren't used for the ORPO trainer
-            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["max_target_length"] = None
-            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
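
Note (illustration, not part of the patch): the net effect of this change is that the DPO/IPO settings previously passed as DPOTrainer keyword arguments in build() now ride on the config object assembled in build_training_arguments(), with the older max_target_length kwarg replaced by max_completion_length. As a rough sketch of the resulting shape, assuming a recent TRL release where DPOConfig exposes these fields (the values below are placeholders, not Axolotl defaults):

    from trl import DPOConfig

    # Settings that used to be dpo_trainer_kwargs entries are now config fields.
    config = DPOConfig(
        output_dir="outputs",
        loss_type="ipo",             # was dpo_trainer_kwargs["loss_type"]
        max_length=2048,             # was dpo_trainer_kwargs["max_length"]
        max_prompt_length=2048,      # was dpo_trainer_kwargs["max_prompt_length"]
        max_completion_length=None,  # renamed from the older max_target_length
        generate_during_eval=False,  # was dpo_trainer_kwargs["generate_during_eval"]
    )
    print(config.loss_type, config.max_length)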