From 000a735c275f8cb1ea2b2871f0e02455842ff02a Mon Sep 17 00:00:00 2001 From: Nicholas Garcia Date: Mon, 22 Jan 2024 17:02:40 -0800 Subject: [PATCH 1/3] Allow bool input for loggers --- llmfoundry/utils/builders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 75438b895e..3d450eec76 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -220,7 +220,7 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: kwargs_dict = { - k: v if isinstance(v, str) else om.to_container(v, resolve=True) + k: v if isinstance(v, str) or isinstance(v, bool) else om.to_container(v, resolve=True) for k, v in kwargs.items() } From 9b1ee4c8d309eabd8b1f608f19457c51be85e9c7 Mon Sep 17 00:00:00 2001 From: Nicholas Garcia Date: Tue, 23 Jan 2024 12:48:48 -0800 Subject: [PATCH 2/3] Convert earlier on --- llmfoundry/utils/builders.py | 15 +++++---------- scripts/train/train.py | 3 ++- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 3d450eec76..29642381f8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -219,21 +219,16 @@ def build_callback( def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: - kwargs_dict = { - k: v if isinstance(v, str) or isinstance(v, bool) else om.to_container(v, resolve=True) - for k, v in kwargs.items() - } - if name == 'wandb': - return WandBLogger(**kwargs_dict) + return WandBLogger(**kwargs) elif name == 'tensorboard': - return TensorboardLogger(**kwargs_dict) + return TensorboardLogger(**kwargs) elif name == 'in_memory_logger': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) elif name == 'mlflow': - return MLFlowLogger(**kwargs_dict) + return MLFlowLogger(**kwargs) elif name == 'inmemory': - return InMemoryLogger(**kwargs_dict) + return InMemoryLogger(**kwargs) else: raise ValueError(f'Not sure how to build logger: {name}') diff --git a/scripts/train/train.py b/scripts/train/train.py index c3da1f1d3c..638ad8aaea 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer: logger_configs: Optional[DictConfig] = pop_config(cfg, 'loggers', must_exist=False, - default_value=None) + default_value=None, + convert=True) callback_configs: Optional[DictConfig] = pop_config(cfg, 'callbacks', must_exist=False, From e037e525da63b1af2bc5a5b548300d60d32d7339 Mon Sep 17 00:00:00 2001 From: Nicholas Garcia Date: Tue, 23 Jan 2024 13:31:29 -0800 Subject: [PATCH 3/3] Fix test case --- tests/utils/test_builders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 9be6630075..303afc9b7d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -135,14 +135,14 @@ def test_build_logger(): with pytest.raises(ValueError): _ = build_logger('unknown', {}) - logger_cfg = DictConfig({ + logger_cfg = { 'project': 'foobar', 'init_kwargs': { 'config': { 'foo': 'bar', } } - }) + } wandb_logger = build_logger('wandb', logger_cfg) # type: ignore assert isinstance(wandb_logger, WandBLogger) assert wandb_logger.project == 'foobar'