diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 599144bd34..f05efe7b82 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1119,12 +1119,17 @@ def get_callbacks(self) -> List[TrainerCallback]: SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) ) if self.cfg.use_mlflow and is_mlflow_available(): + from transformers.integrations.integration_utils import MLflowCallback + from axolotl.utils.callbacks.mlflow_ import ( SaveAxolotlConfigtoMlflowCallback, ) - callbacks.append( - SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) + callbacks.extend( + [ + SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path), + MLflowCallback, + ] ) if self.cfg.use_comet and is_comet_available(): from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback diff --git a/src/axolotl/utils/mlflow_.py b/src/axolotl/utils/mlflow_.py index ce77390342..8710b07d06 100644 --- a/src/axolotl/utils/mlflow_.py +++ b/src/axolotl/utils/mlflow_.py @@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault): # Enable mlflow if experiment name is present if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0: cfg.use_mlflow = True + + # Enable logging hf artifacts in mlflow if value is truthy + if cfg.hf_mlflow_log_artifacts is True: + os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true" diff --git a/tests/test_validation.py b/tests/test_validation.py index 6e0d0ad2a5..fb63977f5c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -13,6 +13,7 @@ from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault +from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.models import check_model_config from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -1432,3 +1433,58 @@ def test_comet_sets_env(self, minimal_cfg): for key in comet_env.keys(): os.environ.pop(key, None) + + +class TestValidationMLflow(BaseValidation): + """ + Validation test for MLflow + """ + + def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg): + cfg = ( + DictDefault( + { + "hf_mlflow_log_artifacts": True, + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + + assert new_cfg.hf_mlflow_log_artifacts is True + + # Check it's not already present in env + assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ + + setup_mlflow_env_vars(new_cfg) + + assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true" + + os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None) + + def test_mlflow_not_used_by_default(self, minimal_cfg): + cfg = DictDefault({}) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_mlflow_env_vars(new_cfg) + + assert cfg.use_mlflow is not True + + cfg = ( + DictDefault( + { + "mlflow_experiment_name": "foo", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + + setup_mlflow_env_vars(new_cfg) + + assert new_cfg.use_mlflow is True + + os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)