diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ccd9d37c0d..cc162d210a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -692,6 +692,9 @@ def build(self, total_num_steps): and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine" ) + training_arguments_kwargs["lr_scheduler_kwargs"] = ( + self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {} + ) training_arguments_kwargs["weight_decay"] = ( self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0 )