Skip to content

Commit

Permalink
Log checkpoints as mlflow artifacts (#1976)
Browse files Browse the repository at this point in the history
* Ensure hf_mlflow_log_artifact config var is set in env

* Add transformer MLflowCallback to callbacks list when mlflow enabled

* Test hf_mlflow_log_artifacts is set correctly

* Test mlflow not being used by default
  • Loading branch information
awhazell authored Oct 22, 2024
1 parent 5c629ee commit 9bd5f7d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,12 +1119,17 @@ def get_callbacks(self) -> List[TrainerCallback]:
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback

from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)

callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/utils/mlflow_.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True

# Enable logging hf artifacts in mlflow if value is truthy
if cfg.hf_mlflow_log_artifacts is True:
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"
56 changes: 56 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars

Expand Down Expand Up @@ -1432,3 +1433,58 @@ def test_comet_sets_env(self, minimal_cfg):

for key in comet_env.keys():
os.environ.pop(key, None)


class TestValidationMLflow(BaseValidation):
"""
Validation test for MLflow
"""

def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
cfg = (
DictDefault(
{
"hf_mlflow_log_artifacts": True,
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)

assert new_cfg.hf_mlflow_log_artifacts is True

# Check it's not already present in env
assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ

setup_mlflow_env_vars(new_cfg)

assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"

os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)

def test_mlflow_not_used_by_default(self, minimal_cfg):
cfg = DictDefault({}) | minimal_cfg

new_cfg = validate_config(cfg)

setup_mlflow_env_vars(new_cfg)

assert cfg.use_mlflow is not True

cfg = (
DictDefault(
{
"mlflow_experiment_name": "foo",
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)

setup_mlflow_env_vars(new_cfg)

assert new_cfg.use_mlflow is True

os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)

0 comments on commit 9bd5f7d

Please sign in to comment.