Skip to content

Commit

Permalink
fix optim_args
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Mar 16, 2024
1 parent 974933d commit dc45927
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
Expand Down Expand Up @@ -1018,7 +1018,13 @@ def build(self, total_num_steps):
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
training_arguments_kwargs["optim_args"] = self.cfg.optim_args
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
optim_args: Optional[Dict[str, Any]] = Field(
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
Expand Down

0 comments on commit dc45927

Please sign in to comment.