From 1993e7fcef3239bd418dc549b440f6345ecf8325 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:01:55 -0700 Subject: [PATCH] Error if metadata matches existing keys (#1313) --- llmfoundry/utils/config_utils.py | 1 + scripts/train/train.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index f977a27213..1fc4a0e96e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -173,6 +173,7 @@ class TrainConfig: # Metadata metadata: Optional[Dict[str, Any]] = None + flatten_metadata: bool = True run_name: Optional[str] = None # Resumption diff --git a/scripts/train/train.py b/scripts/train/train.py index 655b5de938..134058a595 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -320,9 +320,21 @@ def main(cfg: DictConfig) -> Trainer: loggers.append(mosaicml_logger) if train_cfg.metadata is not None: - # Flatten the metadata for logging - logged_cfg.pop('metadata', None) - logged_cfg.update(train_cfg.metadata, merge=True) + # Optionally flatten the metadata for logging + if train_cfg.flatten_metadata: + logged_cfg.pop('metadata', None) + common_keys = set( + logged_cfg.keys(), + ) & set(train_cfg.metadata.keys()) + if len(common_keys) > 0: + raise ValueError( + f'Keys {common_keys} are already present in the config. Please rename them in metadata ' + + + 'or set flatten_metadata=False to avoid flattening the metadata in the logged config.', + ) + + logged_cfg.update(train_cfg.metadata, merge=True) + if mosaicml_logger is not None: mosaicml_logger.log_metrics(train_cfg.metadata) mosaicml_logger._flush_metadata(force_flush=True)