From 4961436ff199837cdee223a1d47e7d0d258fb4fd Mon Sep 17 00:00:00 2001 From: Nicholas Garcia Date: Tue, 23 Jan 2024 14:14:46 -0800 Subject: [PATCH] Allow bool input for loggers (#897) * Allow bool input for loggers * Convert earlier on * Fix test case --- llmfoundry/utils/builders.py | 15 +++++---------- scripts/train/train.py | 3 ++- tests/utils/test_builders.py | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 75438b895e..29642381f8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -219,21 +219,16 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: - kwargs_dict = { - k: v if isinstance(v, str) else om.to_container(v, resolve=True) - for k, v in kwargs.items() - } - if name == 'wandb': - return WandBLogger(**kwargs_dict) + return WandBLogger(**kwargs) elif name == 'tensorboard': - return TensorboardLogger(**kwargs_dict) + return TensorboardLogger(**kwargs) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) elif name == 'mlflow': - return MLFlowLogger(**kwargs_dict) + return MLFlowLogger(**kwargs) elif name == 'inmemory': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) else: raise ValueError(f'Not sure how to build logger: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index c3da1f1d3c..638ad8aaea 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer: logger_configs: Optional[DictConfig] = pop_config(cfg, 'loggers', must_exist=False, - default_value=None) + default_value=None, + convert=True) callback_configs: Optional[DictConfig] = pop_config(cfg, 'callbacks', must_exist=False, diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 9be6630075..303afc9b7d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -135,14 +135,14 @@ def test_build_logger(): with pytest.raises(ValueError): _ = build_logger('unknown', {}) - logger_cfg = DictConfig({ + logger_cfg = { 'project': 'foobar', 'init_kwargs': { 'config': { 'foo': 'bar', } } - }) + } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) assert wandb_logger.project == 'foobar'