Skip to content

Commit

Permalink
Refactor build_logger to convert kwargs and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Dec 8, 2023
1 parent 9e11cf7 commit aa652f3
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 9 deletions.
18 changes: 13 additions & 5 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,24 @@ def build_callback(


def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination:
    """Builds a logger destination from its registry name and config.

    Args:
        name: Registry key of the logger to construct (one of ``'wandb'``,
            ``'tensorboard'``, ``'in_memory_logger'``/``'inmemory'``,
            ``'mlflow'``).
        kwargs: Logger configuration. Values may be OmegaConf containers
            (``DictConfig``/``ListConfig``); those are resolved to plain
            builtin containers before being passed to the logger.

    Returns:
        The constructed ``LoggerDestination``.

    Raises:
        ValueError: If ``name`` is not a recognized logger type.
    """
    # Resolve OmegaConf containers to plain dicts/lists: some backends
    # (e.g. wandb.init) reject DictConfig values and require builtin
    # container types. Only config objects are converted —
    # om.to_container raises ValueError on plain primitives (ints,
    # bools, strs), so everything else is passed through untouched.
    kwargs_dict = {
        k: om.to_container(v, resolve=True) if om.is_config(v) else v
        for k, v in kwargs.items()
    }

    if name == 'wandb':
        return WandBLogger(**kwargs_dict)
    elif name == 'tensorboard':
        return TensorboardLogger(**kwargs_dict)
    elif name in ('in_memory_logger', 'inmemory'):
        # Both aliases map to the same in-memory logger.
        return InMemoryLogger(**kwargs_dict)
    elif name == 'mlflow':
        return MLFlowLogger(**kwargs_dict)
    else:
        raise ValueError(f'Not sure how to build logger: {name}')

Expand Down
2 changes: 1 addition & 1 deletion scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def evaluate_model(
callbacks.append(eval_gauntlet_callback)

loggers: List[LoggerDestination] = [
build_logger(name, om.to_container(logger_cfg, resolve=True))
build_logger(name, logger_cfg)
for name, logger_cfg in loggers_cfg.items()
]

Expand Down
2 changes: 1 addition & 1 deletion scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def main(cfg: DictConfig) -> Trainer:

# Loggers
loggers = [
build_logger(str(name), om.to_container(logger_cfg, resolve=True))
build_logger(str(name), logger_cfg)
for name, logger_cfg in logger_configs.items()
] if logger_configs else []

Expand Down
30 changes: 28 additions & 2 deletions tests/utils/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn as nn
from composer.callbacks import Generate
from composer.core import Evaluator
from composer.loggers import WandBLogger
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase
Expand All @@ -20,8 +21,8 @@
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_callback, build_eval_loaders,
build_evaluators, build_optimizer,
build_tokenizer)
build_evaluators, build_logger,
build_optimizer, build_tokenizer)


@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
Expand Down Expand Up @@ -130,6 +131,31 @@ def test_build_hf_checkpointer_callback():
assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict


def test_build_logger():
    """build_logger rejects unknown names and hands loggers plain builtin
    containers rather than OmegaConf config objects."""
    # An unrecognized logger name must raise.
    with pytest.raises(ValueError):
        _ = build_logger('unknown', {})

    config = DictConfig({
        'project': 'foobar',
        'init_kwargs': {
            'config': {
                'foo': 'bar',
            }
        }
    })
    logger = build_logger('wandb', config)  # type: ignore
    assert isinstance(logger, WandBLogger)
    assert logger.project == 'foobar'

    # wandb.init() fails if its config is a DictConfig instead of a plain
    # dict, so verify the conversion happened all the way down.
    init_kwargs = logger._init_kwargs
    assert isinstance(init_kwargs, dict)
    assert isinstance(init_kwargs['config'], dict)
    assert init_kwargs == {'config': {'foo': 'bar'}, 'project': 'foobar'}


class _DummyModule(nn.Module):

def __init__(self, device: str = 'cpu', dtype: torch.dtype = torch.float32):
Expand Down

0 comments on commit aa652f3

Please sign in to comment.