diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 72ca19834b..b6a5acf6d9 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -39,6 +39,7 @@
     'update_batch_size_info',
     'process_init_device',
     'log_config',
+    'log_dataset_uri',
 ]


@@ -508,7 +509,6 @@ def log_config(cfg: Dict[str, Any]) -> None:

     if 'mlflow' in loggers and mlflow.active_run():
         mlflow.log_params(params=cfg)
-        _log_dataset_uri(cfg)


 def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]:
@@ -619,12 +619,15 @@ def _process_data_source(
         log.warning('DataSource Not Found.')


-def _log_dataset_uri(cfg: Dict[str, Any]) -> None:
+def log_dataset_uri(cfg: Dict[str, Any]) -> None:
     """Logs dataset tracking information to MLflow.

     Args:
         cfg (DictConfig): A config dictionary of a run
     """
+    loggers = cfg.get('loggers', None) or {}
+    if 'mlflow' not in loggers or not mlflow.active_run():
+        return
     # Figure out which data source to use
     data_paths = _parse_source_dataset(cfg)

diff --git a/scripts/train/train.py b/scripts/train/train.py
index e0c2b8a94f..bfeec14e0b 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -55,6 +55,7 @@
     TRAIN_CONFIG_KEYS,
     TrainConfig,
     log_config,
+    log_dataset_uri,
     make_dataclass_and_log_config,
     pop_config,
     process_init_device,
@@ -530,6 +531,7 @@ def main(cfg: DictConfig) -> Trainer:
     if train_cfg.log_config:
         log.info('Logging config')
         log_config(logged_cfg)
+        log_dataset_uri(logged_cfg)
     torch.cuda.empty_cache()
     gc.collect()

diff --git a/tests/utils/test_mlflow_logging.py b/tests/utils/test_mlflow_logging.py
index 04a600d44c..d2dd5f4689 100644
--- a/tests/utils/test_mlflow_logging.py
+++ b/tests/utils/test_mlflow_logging.py
@@ -7,8 +7,8 @@
 import pytest

 from llmfoundry.utils.config_utils import (
-    _log_dataset_uri,
     _parse_source_dataset,
+    log_dataset_uri,
 )

 mlflow = pytest.importorskip('mlflow')
@@ -84,10 +84,12 @@ def test_log_dataset_uri():
         }},
         source_dataset_train='huggingface/train_dataset',
         source_dataset_eval='huggingface/eval_dataset',
+        loggers={'mlflow': {}},
     )

-    with patch('mlflow.log_input') as mock_log_input:
-        _log_dataset_uri(cfg)
+    with patch('mlflow.log_input') as mock_log_input, \
+        patch('mlflow.active_run', return_value=True):
+        log_dataset_uri(cfg)

     assert mock_log_input.call_count == 2

     meta_dataset_calls = [
         args[0] for args, _ in mock_log_input.call_args_list
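For context, a minimal sketch (not part of this diff) of a companion test that exercises the new early-return guard in `log_dataset_uri`, mirroring the mocking pattern already used in `test_log_dataset_uri` above. The test name and the `wandb`-only config are hypothetical; only the guard behavior (`loggers` without `'mlflow'` means no MLflow calls) comes from the diff.

```python
# Hypothetical companion test (not in the PR): log_dataset_uri should return
# early and never call mlflow.log_input when the 'mlflow' logger is absent
# from the config, per the guard added in this diff.
from unittest.mock import patch

from llmfoundry.utils.config_utils import log_dataset_uri


def test_log_dataset_uri_skips_without_mlflow_logger():
    # Config with a non-MLflow logger only; the shape is illustrative.
    cfg = {'loggers': {'wandb': {}}}

    with patch('mlflow.log_input') as mock_log_input:
        log_dataset_uri(cfg)

    # The early return should prevent any dataset logging.
    mock_log_input.assert_not_called()
```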