Skip to content

Commit

Permalink
move deprecated kwargs from trainer to trainingargs (#2028)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Nov 10, 2024
1 parent e4af51e commit d356740
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1933,6 +1933,12 @@ def build_training_arguments(self, total_num_steps):

else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting

Expand All @@ -1956,7 +1962,6 @@ def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
Expand All @@ -1970,12 +1975,6 @@ def build(self, total_num_steps):
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]

# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
Expand Down

0 comments on commit d356740

Please sign in to comment.