diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 96be6ad45d..349dd9c017 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb -def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: +def build_callback(name: str, kwargs: Union[DictConfig, Dict[str, + Any]]) -> Callback: if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -118,7 +119,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': if isinstance(kwargs, DictConfig): - kwargs = om.to_object(kwargs) + kwargs = om.to_object(kwargs) # pyright: ignore return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index a8b484bb24..237e27b52b 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -6,7 +6,6 @@ import pytest from composer.callbacks import Generate -from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 47f8408dce..73a027704c 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch from composer import Trainer from composer.loggers import MLFlowLogger