diff --git a/scripts/train/train.py b/scripts/train/train.py index 60ee55955e..88f776375f 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -383,6 +383,12 @@ def main(cfg: DictConfig) -> Trainer: 'compile_config', must_exist=False, default_value=None) + metadata: Optional[Dict[str, str]] = pop_config(cfg, + 'metadata', + must_exist=False, + default_value=None, + convert=True) + # Enable autoresume from model checkpoints if possible autoresume_default: bool = False if logged_cfg.get('run_name', None) is not None \ @@ -460,6 +466,14 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger = MosaicMLLogger() loggers.append(mosaicml_logger) + if metadata is not None: + # Flatten the metadata for logging + logged_cfg.pop('metadata', None) + logged_cfg.update(metadata, merge=True) + if mosaicml_logger is not None: + mosaicml_logger.log_metrics(metadata) + mosaicml_logger._flush_metadata(force_flush=True) + # Profiling profiler: Optional[Profiler] = None profiler_cfg: Optional[DictConfig] = pop_config(cfg,