Skip to content

Commit

Permalink
Move MLFlow dataset outside of log_config (#1234)
Browse files (browse the repository at this point in the history)
* move txt log

* typo

* Update scripts/train/train.py

* train config

* debug

* source data

* verbose

* debug

* debug

* check if mlflow is active

* fix tests

* move mlflow check to train

* update test

* precommit

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
KuuCi and dakinggg authored May 24, 2024
1 parent fdaa58b commit 1e4bd37
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
7 changes: 5 additions & 2 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
'update_batch_size_info',
'process_init_device',
'log_config',
'log_dataset_uri',
]


Expand Down Expand Up @@ -508,7 +509,6 @@ def log_config(cfg: Dict[str, Any]) -> None:

if 'mlflow' in loggers and mlflow.active_run():
mlflow.log_params(params=cfg)
_log_dataset_uri(cfg)


def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
Expand Down Expand Up @@ -619,12 +619,15 @@ def _process_data_source(
log.warning('DataSource Not Found.')


def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
def log_dataset_uri(cfg: Dict[str, Any]) -> None:
"""Logs dataset tracking information to MLflow.
Args:
cfg (DictConfig): A config dictionary of a run
"""
loggers = cfg.get('loggers', None) or {}
if 'mlflow' not in loggers or not mlflow.active_run():
return
# Figure out which data source to use
data_paths = _parse_source_dataset(cfg)

Expand Down
2 changes: 2 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TRAIN_CONFIG_KEYS,
TrainConfig,
log_config,
log_dataset_uri,
make_dataclass_and_log_config,
pop_config,
process_init_device,
Expand Down Expand Up @@ -530,6 +531,7 @@ def main(cfg: DictConfig) -> Trainer:
if train_cfg.log_config:
log.info('Logging config')
log_config(logged_cfg)
log_dataset_uri(logged_cfg)
torch.cuda.empty_cache()
gc.collect()

Expand Down
8 changes: 5 additions & 3 deletions tests/utils/test_mlflow_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import pytest

from llmfoundry.utils.config_utils import (
_log_dataset_uri,
_parse_source_dataset,
log_dataset_uri,
)

mlflow = pytest.importorskip('mlflow')
Expand Down Expand Up @@ -84,10 +84,12 @@ def test_log_dataset_uri():
}},
source_dataset_train='huggingface/train_dataset',
source_dataset_eval='huggingface/eval_dataset',
loggers={'mlflow': {}},
)

with patch('mlflow.log_input') as mock_log_input:
_log_dataset_uri(cfg)
with patch('mlflow.log_input') as mock_log_input, \
patch('mlflow.active_run', return_value=True):
log_dataset_uri(cfg)
assert mock_log_input.call_count == 2
meta_dataset_calls = [
args[0] for args, _ in mock_log_input.call_args_list
Expand Down

0 comments on commit 1e4bd37

Please sign in to comment.