From 4961436ff199837cdee223a1d47e7d0d258fb4fd Mon Sep 17 00:00:00 2001
From: Nicholas Garcia <nicholasgcgarcia@gmail.com>
Date: Tue, 23 Jan 2024 14:14:46 -0800
Subject: [PATCH] Allow bool input for loggers (#897)

* Allow bool input for loggers

* Convert earlier on

* Fix test case
---
 llmfoundry/utils/builders.py | 15 +++++----------
 scripts/train/train.py       |  3 ++-
 tests/utils/test_builders.py |  4 ++--
 3 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 75438b895e..29642381f8 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -219,21 +219,16 @@ def build_callback(
 
 
 def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
-    kwargs_dict = {
-        k: v if isinstance(v, str) else om.to_container(v, resolve=True)
-        for k, v in kwargs.items()
-    }
-
     if name == 'wandb':
-        return WandBLogger(**kwargs_dict)
+        return WandBLogger(**kwargs)
     elif name == 'tensorboard':
-        return TensorboardLogger(**kwargs_dict)
+        return TensorboardLogger(**kwargs)
     elif name == 'in_memory_logger':
-        return InMemoryLogger(**kwargs_dict)
+        return InMemoryLogger(**kwargs)
     elif name == 'mlflow':
-        return MLFlowLogger(**kwargs_dict)
+        return MLFlowLogger(**kwargs)
     elif name == 'inmemory':
-        return InMemoryLogger(**kwargs_dict)
+        return InMemoryLogger(**kwargs)
     else:
         raise ValueError(f'Not sure how to build logger: {name}')
 
diff --git a/scripts/train/train.py b/scripts/train/train.py
index c3da1f1d3c..638ad8aaea 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -278,7 +278,8 @@ def main(cfg: DictConfig) -> Trainer:
     logger_configs: Optional[DictConfig] = pop_config(cfg,
                                                       'loggers',
                                                       must_exist=False,
-                                                      default_value=None)
+                                                      default_value=None,
+                                                      convert=True)
     callback_configs: Optional[DictConfig] = pop_config(cfg,
                                                         'callbacks',
                                                         must_exist=False,
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index 9be6630075..303afc9b7d 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -135,14 +135,14 @@ def test_build_logger():
     with pytest.raises(ValueError):
         _ = build_logger('unknown', {})
 
-    logger_cfg = DictConfig({
+    logger_cfg = {
         'project': 'foobar',
         'init_kwargs': {
             'config': {
                 'foo': 'bar',
             }
         }
-    })
+    }
     wandb_logger = build_logger('wandb', logger_cfg)  # type: ignore
     assert isinstance(wandb_logger, WandBLogger)
     assert wandb_logger.project == 'foobar'