diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 5e52dafabf..260aba747f 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1087,13 +1087,21 @@ def build_training_arguments(self, total_num_steps): "use_reentrant": False } + # set save_strategy and save_steps + if self.cfg.save_steps: + training_args_kwargs["save_strategy"] = "steps" + training_args_kwargs["save_steps"] = self.cfg.save_steps + elif self.cfg.save_strategy: + training_args_kwargs["save_strategy"] = self.cfg.save_strategy + else: + # default to saving each epoch if not defined + training_args_kwargs["save_strategy"] = "epoch" + training_args = TrainingArguments( per_device_train_batch_size=self.cfg.micro_batch_size, max_steps=self.cfg.max_steps or total_num_steps, gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, learning_rate=self.cfg.learning_rate, - save_strategy="steps", - save_steps=self.cfg.save_steps, output_dir=self.cfg.output_dir, warmup_steps=self.cfg.warmup_steps, logging_first_step=True,