diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index c9666566bf..5e2795f9c9 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -71,4 +71,4 @@
     'utils',
 ]
 
-__version__ = '0.9.0.dev0'
+__version__ = '0.10.0.dev0'
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index b6a5acf6d9..5ab148bbe8 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -67,6 +67,7 @@ class EvalConfig:
     # Logging parameters
     python_log_level: Optional[str] = 'debug'
     loggers: Optional[Dict[str, Any]] = None
+    console_log_interval: Union[int, str] = '1ba'
     log_config: bool = True
 
     # Model/run parameters
@@ -180,6 +181,11 @@ class TrainConfig:
     # Variables to ignore
     variables: Optional[Dict[str, Any]] = None
 
+    # Fields created by `update_batch_size_info`
+    n_gpus: int = MISSING
+    device_train_batch_size: int = MISSING
+    device_train_grad_accum: str = MISSING
+
 
 TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)}
 
@@ -242,7 +248,6 @@ def make_dataclass_and_log_config(
     icl_tasks_required: bool = False,
 ) -> Tuple[Dict[str, Any], T]:
     """Converts a DictConfig to a dataclass and creates a logged config."""
-    # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
     assert isinstance(unstructured_config, dict)
     assert all(isinstance(k, str) for k in unstructured_config.keys())
@@ -289,11 +294,9 @@ def make_dataclass_and_log_config(
         unstructured_config['variables'] = {}
 
     for key in extraneous_keys:
-        warnings.warn(
-            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.',
-            category=DeprecationWarning,
+        raise ValueError(
+            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Please place any variables under the `variables` key.',
         )
-        unstructured_config['variables'][key] = unstructured_config.pop(key)
 
     dataclass_dict_config: DictConfig = om.structured(
         dataclass_constructor(**unstructured_config),
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 3cf3d9551d..f2a70b526d 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -553,6 +553,5 @@ def main(cfg: DictConfig) -> Trainer:
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
-    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py
index a56778538c..01f3760d26 100644
--- a/tests/a_scripts/eval/test_eval.py
+++ b/tests/a_scripts/eval/test_eval.py
@@ -13,7 +13,7 @@
 
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container
 from scripts.eval.eval import main  # noqa: E402
 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg
 
@@ -134,6 +134,14 @@ def test_loader_eval(
     test_cfg.eval_interval = '1ba'
     test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})})
 
+    # This test uses a training yaml with training-only keys present.
+    # We exclude these keys before calling `main` from the eval script.
+    allowed_keys = EVAL_CONFIG_KEYS
+    present_keys = set(test_cfg.keys())
+    keys_to_pop = present_keys.difference(allowed_keys)
+
+    [test_cfg.pop(key) for key in keys_to_pop]
+
     trainers, eval_gauntlet_df = main(test_cfg)
 
     assert eval_gauntlet_df is None
diff --git a/tests/a_scripts/eval/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py
index 98b15743b3..0ca5765a26 100644
--- a/tests/a_scripts/eval/test_eval_inputs.py
+++ b/tests/a_scripts/eval/test_eval_inputs.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -42,12 +41,13 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
                 omegaconf.errors.InterpolationKeyError,
                 omegaconf.errors.MissingMandatoryValue,
                 TypeError,
+                ValueError,
             )):
                 cfg[p + '-mispelled'] = cfg.pop(p)
                 main(cfg)
             cfg[p] = cfg.pop(p + '-mispelled')
 
-    def test_optional_mispelled_params_raise_warning(
+    def test_optional_mispelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -67,15 +67,8 @@ def test_optional_mispelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-mispelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
 
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py
index 5a3b21dc3b..5901d53e94 100644
--- a/tests/a_scripts/train/test_train_inputs.py
+++ b/tests/a_scripts/train/test_train_inputs.py
@@ -3,7 +3,6 @@
 import copy
 import json
 import os
-import warnings
 
 import omegaconf
 import pytest
@@ -63,7 +62,9 @@ def cfg(self, foundry_dir: str) -> DictConfig:
     def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
         """Check that mandatory misspelled inputs fail to train."""
         cfg.trai_loader = cfg.pop('train_loader')
-        with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)):
+        with pytest.raises(
+            (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError),
+        ):
             main(cfg)
 
     def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
@@ -89,7 +90,7 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
                 main(cfg)
             cfg[param] = orig_param
 
-    def test_optional_misspelled_params_raise_warning(
+    def test_optional_misspelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
@@ -113,15 +114,8 @@ def test_optional_misspelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-misspelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
 
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
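
Note on the behavior change (illustrative, not part of the patch): `make_dataclass_and_log_config` previously downgraded unknown top-level config keys to a DeprecationWarning and folded them under `variables`; after this change it raises a ValueError. A minimal standalone sketch of the new strict handling follows. The helper name `check_extraneous_keys` and its signature are hypothetical; only the error message mirrors the config_utils.py hunk above.

    # Hypothetical sketch of the stricter key handling introduced by this
    # patch; the real logic lives inside make_dataclass_and_log_config.
    from typing import Any, Dict, Set

    def check_extraneous_keys(
        unstructured_config: Dict[str, Any],
        allowed_keys: Set[str],
    ) -> None:
        """Fail fast on unknown top-level keys instead of warning."""
        extraneous_keys = set(unstructured_config.keys()) - allowed_keys
        for key in sorted(extraneous_keys):
            # Before: warnings.warn(..., category=DeprecationWarning) and the
            # key was silently moved under `variables`. Now the run aborts:
            raise ValueError(
                f'Unused parameter {key} found in cfg. Please check your yaml '
                'to ensure this parameter is necessary. Please place any '
                'variables under the `variables` key.',
            )

    # Callers migrate by nesting extras under `variables` in their yaml, e.g.
    #   variables:
    #     data_local: ./my-data    # instead of a top-level `data_local` key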