From 1e4bd3727cf2a7c6b0ce791120907f4d37069f02 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Fri, 24 May 2024 15:43:45 -0700 Subject: [PATCH] Move MLFlow dataset outside of log_config (#1234) * move txt log * typo * Update scripts/train/train.py * train config * debug * source data * verbose * debug * debug * check if mlflow is active * fex tests * move mlflow check to train * update test * precommit --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/utils/config_utils.py | 7 +++++-- scripts/train/train.py | 2 ++ tests/utils/test_mlflow_logging.py | 8 +++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 72ca19834b..b6a5acf6d9 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -39,6 +39,7 @@ 'update_batch_size_info', 'process_init_device', 'log_config', + 'log_dataset_uri', ] @@ -508,7 +509,6 @@ def log_config(cfg: Dict[str, Any]) -> None: if 'mlflow' in loggers and mlflow.active_run(): mlflow.log_params(params=cfg) - _log_dataset_uri(cfg) def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]: @@ -619,12 +619,15 @@ def _process_data_source( log.warning('DataSource Not Found.') -def _log_dataset_uri(cfg: Dict[str, Any]) -> None: +def log_dataset_uri(cfg: Dict[str, Any]) -> None: """Logs dataset tracking information to MLflow. Args: cfg (DictConfig): A config dictionary of a run """ + loggers = cfg.get('loggers', None) or {} + if 'mlflow' not in loggers or not mlflow.active_run(): + return # Figure out which data source to use data_paths = _parse_source_dataset(cfg) diff --git a/scripts/train/train.py b/scripts/train/train.py index e0c2b8a94f..bfeec14e0b 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -55,6 +55,7 @@ TRAIN_CONFIG_KEYS, TrainConfig, log_config, + log_dataset_uri, make_dataclass_and_log_config, pop_config, process_init_device, @@ -530,6 +531,7 @@ def main(cfg: DictConfig) -> Trainer: if train_cfg.log_config: log.info('Logging config') log_config(logged_cfg) + log_dataset_uri(logged_cfg) torch.cuda.empty_cache() gc.collect() diff --git a/tests/utils/test_mlflow_logging.py b/tests/utils/test_mlflow_logging.py index 04a600d44c..d2dd5f4689 100644 --- a/tests/utils/test_mlflow_logging.py +++ b/tests/utils/test_mlflow_logging.py @@ -7,8 +7,8 @@ import pytest from llmfoundry.utils.config_utils import ( - _log_dataset_uri, _parse_source_dataset, + log_dataset_uri, ) mlflow = pytest.importorskip('mlflow') @@ -84,10 +84,12 @@ def test_log_dataset_uri(): }}, source_dataset_train='huggingface/train_dataset', source_dataset_eval='huggingface/eval_dataset', + loggers={'mlflow': {}}, ) - with patch('mlflow.log_input') as mock_log_input: - _log_dataset_uri(cfg) + with patch('mlflow.log_input') as mock_log_input, \ + patch('mlflow.active_run', return_value=True): + log_dataset_uri(cfg) assert mock_log_input.call_count == 2 meta_dataset_calls = [ args[0] for args, _ in mock_log_input.call_args_list