From aa652f349676b12da3c6d0e6e1e67a4ce6f7286f Mon Sep 17 00:00:00 2001
From: Anna Pfohl
Date: Fri, 8 Dec 2023 09:42:07 -0800
Subject: [PATCH] build_loggers and add tests

---
 llmfoundry/utils/builders.py | 17 ++++++++++++-----
 scripts/eval/eval.py         |  2 +-
 scripts/train/train.py       |  2 +-
 tests/utils/test_builders.py | 30 ++++++++++++++++++++++++++++--
 4 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 23533238e8..0aa41ca153 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -218,16 +218,23 @@ def build_callback(
 
 
 def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
+    # Resolve non-string values (e.g. DictConfig) to plain Python
+    # containers so logger constructors receive native dicts/lists.
+    kwargs_dict = {
+        k: v if isinstance(v, str) else om.to_container(v, resolve=True)
+        for k, v in kwargs.items()
+    }
+
     if name == 'wandb':
-        return WandBLogger(**kwargs)
+        return WandBLogger(**kwargs_dict)
     elif name == 'tensorboard':
-        return TensorboardLogger(**kwargs)
+        return TensorboardLogger(**kwargs_dict)
     elif name == 'in_memory_logger':
-        return InMemoryLogger(**kwargs)
+        return InMemoryLogger(**kwargs_dict)
     elif name == 'mlflow':
-        return MLFlowLogger(**kwargs)
+        return MLFlowLogger(**kwargs_dict)
     elif name == 'inmemory':
-        return InMemoryLogger(**kwargs)
+        return InMemoryLogger(**kwargs_dict)
     else:
         raise ValueError(f'Not sure how to build logger: {name}')
 
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index 194bba05d0..369a894720 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -140,7 +140,7 @@ def evaluate_model(
         callbacks.append(eval_gauntlet_callback)
 
     loggers: List[LoggerDestination] = [
-        build_logger(name, om.to_container(logger_cfg, resolve=True))
+        build_logger(name, logger_cfg)
         for name, logger_cfg in loggers_cfg.items()
     ]
 
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 466e288415..58987e4e62 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -453,7 +453,7 @@ def main(cfg: DictConfig) -> Trainer:
 
     # Loggers
     loggers = [
-        build_logger(str(name), om.to_container(logger_cfg, resolve=True))
+        build_logger(str(name), logger_cfg)
         for name, logger_cfg in logger_configs.items()
     ] if logger_configs else []
 
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index 20b2c4669c..9be6630075 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -12,6 +12,7 @@
 import torch.nn as nn
 from composer.callbacks import Generate
 from composer.core import Evaluator
+from composer.loggers import WandBLogger
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
@@ -20,8 +21,8 @@
 
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
                                        build_callback, build_eval_loaders,
-                                       build_evaluators, build_optimizer,
-                                       build_tokenizer)
+                                       build_evaluators, build_logger,
+                                       build_optimizer, build_tokenizer)
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -130,6 +131,31 @@ def test_build_hf_checkpointer_callback():
     assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict
 
 
+def test_build_logger():
+    with pytest.raises(ValueError):
+        _ = build_logger('unknown', {})
+
+    logger_cfg = DictConfig({
+        'project': 'foobar',
+        'init_kwargs': {
+            'config': {
+                'foo': 'bar',
+            }
+        }
+    })
+    wandb_logger = build_logger('wandb', logger_cfg)  # type: ignore
+    assert isinstance(wandb_logger, WandBLogger)
+    assert wandb_logger.project == 'foobar'
+
+    # Confirm the type conversion from DictConfig to dict;
+    # wandb.init() will fail if the config is not explicitly
+    # of dict type.
+    ik = wandb_logger._init_kwargs
+    assert ik == {'config': {'foo': 'bar'}, 'project': 'foobar'}
+    assert isinstance(ik, dict)
+    assert isinstance(ik['config'], dict)
+
+
 class _DummyModule(nn.Module):
 
     def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):
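
--
A minimal usage sketch for reviewers (not part of the patch), assuming
wandb is installed; the names and values mirror the test above:

    from omegaconf import OmegaConf as om

    from llmfoundry.utils.builders import build_logger

    # Nested values arrive as DictConfig when a YAML config is loaded;
    # build_logger now converts them to plain dicts before constructing
    # the logger, so wandb.init() receives native Python types.
    logger_cfg = om.create({
        'project': 'foobar',
        'init_kwargs': {'config': {'foo': 'bar'}},
    })
    wandb_logger = build_logger('wandb', logger_cfg)
    assert isinstance(wandb_logger._init_kwargs['config'], dict)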