diff --git a/docs/config.qmd b/docs/config.qmd index 8329f35535..b6c0cb852a 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -266,6 +266,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step # mlflow configuration if you're using it mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name +mlflow_run_name: # Your run name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry # Comet configuration if you're using it diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b1ee519dc4..9c12b6141a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1445,9 +1445,12 @@ def build(self, total_num_steps): report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to - training_arguments_kwargs["run_name"] = ( - self.cfg.wandb_name if self.cfg.use_wandb else None - ) + if self.cfg.use_wandb: + training_arguments_kwargs["run_name"] = self.cfg.wandb_name + elif self.cfg.use_mlflow: + training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name + else: + training_arguments_kwargs["run_name"] = None training_arguments_kwargs["optim"] = ( self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1c33b59078..1a269b7982 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -447,6 +447,7 @@ class MLFlowConfig(BaseModel): use_mlflow: Optional[bool] = None mlflow_tracking_uri: Optional[str] = None mlflow_experiment_name: Optional[str] = None + mlflow_run_name: Optional[str] = None hf_mlflow_log_artifacts: Optional[bool] = None