Bump Version to 0.10.0.dev0 #1255

Merged: 22 commits, Jun 7, 2024
llmfoundry/__init__.py (1 addition, 1 deletion)
```diff
@@ -71,4 +71,4 @@
     'utils',
 ]
 
-__version__ = '0.9.0.dev0'
+__version__ = '0.10.0.dev0'
```
llmfoundry/utils/config_utils.py (8 additions, 5 deletions)
```diff
@@ -67,6 +67,7 @@ class EvalConfig:
     # Logging parameters
     python_log_level: Optional[str] = 'debug'
     loggers: Optional[Dict[str, Any]] = None
+    console_log_interval: Union[int, str] = '1ba'
     log_config: bool = True
 
     # Model/run parameters
```
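The `'1ba'` default is Composer's time-string notation (`ba` for batches, `ep` for epochs, and so on), which is why the field is typed `Union[int, str]`. As a hedged aside, not part of this diff, the string form can be parsed with Composer's `Time` helper:

```python
# Sketch only: parsing a Composer time string such as the '1ba' default.
from composer.core import Time

interval = Time.from_timestring('1ba')
print(interval.value, interval.unit)  # 1 TimeUnit.BATCH
```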
```diff
@@ -180,6 +181,11 @@ class TrainConfig:
     # Variables to ignore
     variables: Optional[Dict[str, Any]] = None
 
+    # Fields created by `update_batch_size_info`
+    n_gpus: int = MISSING
+    device_train_batch_size: int = MISSING
+    device_train_grad_accum: str = MISSING
+
 
 TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)}
```
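These fields use omegaconf's `MISSING` sentinel: they exist on the schema but hold no value until `update_batch_size_info` populates them, and reading one beforehand raises `MissingMandatoryValue`. A minimal sketch of that behavior, using a hypothetical stand-in dataclass rather than the real `TrainConfig`:

```python
# Minimal sketch of omegaconf's MISSING semantics (stand-in dataclass).
from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue

@dataclass
class RunConfig:  # hypothetical stand-in for TrainConfig
    n_gpus: int = MISSING

cfg = OmegaConf.structured(RunConfig)
try:
    _ = cfg.n_gpus  # not yet populated
except MissingMandatoryValue:
    print('n_gpus must be set before use')
cfg.n_gpus = 8  # e.g. what update_batch_size_info might compute
```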

```diff
@@ -242,7 +248,6 @@ def make_dataclass_and_log_config(
     icl_tasks_required: bool = False,
 ) -> Tuple[Dict[str, Any], T]:
     """Converts a DictConfig to a dataclass and creates a logged config."""
-    # Resolve all interpolation variables as early as possible
     unstructured_config = om.to_container(cfg, resolve=True)
     assert isinstance(unstructured_config, dict)
     assert all(isinstance(k, str) for k in unstructured_config.keys())
```
```diff
@@ -289,11 +294,9 @@
         unstructured_config['variables'] = {}
 
     for key in extraneous_keys:
-        warnings.warn(
-            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.',
-            category=DeprecationWarning,
+        raise ValueError(
+            f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Please place any variables under the `variables` key.',
         )
-        unstructured_config['variables'][key] = unstructured_config.pop(key)
 
     dataclass_dict_config: DictConfig = om.structured(
         dataclass_constructor(**unstructured_config),
```
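The behavioral change here: an unrecognized top-level key used to be reinterpreted as a logging variable with a `DeprecationWarning`; it is now a hard error. A self-contained sketch of the new fail-fast logic (the dataclass and keys are illustrative, not the real signature):

```python
# Illustrative sketch of the fail-fast check, mirroring the diff above.
from dataclasses import dataclass, fields

@dataclass
class StubConfig:  # hypothetical stand-in for TrainConfig/EvalConfig
    max_seq_len: int = 2048

allowed_keys = {f.name for f in fields(StubConfig)}
user_cfg = {'max_seq_len': 2048, 'max_seq_length': 512}  # second key is a typo
for key in set(user_cfg) - allowed_keys:
    raise ValueError(
        f'Unused parameter {key} found in cfg. Please check your yaml to '
        'ensure this parameter is necessary. Please place any variables '
        'under the `variables` key.',
    )
```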
scripts/train/train.py (0 additions, 1 deletion)
```diff
@@ -553,6 +553,5 @@ def main(cfg: DictConfig) -> Trainer:
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
-    om.resolve(cfg)
     assert isinstance(cfg, DictConfig)
     main(cfg)
```
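This deletion pairs with the config_utils change above: `make_dataclass_and_log_config` already resolves interpolations via `om.to_container(cfg, resolve=True)`, so the explicit `om.resolve(cfg)` at the call site was redundant. A small demonstration of that resolution (variable names assumed):

```python
# Sketch: to_container(..., resolve=True) resolves ${...} interpolations.
from omegaconf import OmegaConf as om

cfg = om.create({
    'variables': {'lr': 1e-4},
    'optimizer': {'lr': '${variables.lr}'},
})
container = om.to_container(cfg, resolve=True)
print(container['optimizer']['lr'])  # 0.0001
```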
tests/a_scripts/eval/test_eval.py (9 additions, 1 deletion)
```diff
@@ -13,7 +13,7 @@
 
 from llmfoundry.utils import build_tokenizer
 from llmfoundry.utils.builders import build_composer_model
-from llmfoundry.utils.config_utils import to_dict_container
+from llmfoundry.utils.config_utils import EVAL_CONFIG_KEYS, to_dict_container
 from scripts.eval.eval import main  # noqa: E402
 from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg
 
```
```diff
@@ -134,6 +134,14 @@ def test_loader_eval(
     test_cfg.eval_interval = '1ba'
     test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})})
 
+    # This test uses a training yaml with training-only keys present.
+    # We exclude these keys before calling `main` from the eval script.
+    allowed_keys = EVAL_CONFIG_KEYS
+    present_keys = set(test_cfg.keys())
+    keys_to_pop = present_keys.difference(allowed_keys)
+
+    [test_cfg.pop(key) for key in keys_to_pop]
+
     trainers, eval_gauntlet_df = main(test_cfg)
 
     assert eval_gauntlet_df is None
```
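`EVAL_CONFIG_KEYS` itself is not shown in this diff; by analogy with the `TRAIN_CONFIG_KEYS` definition above, it is presumably the set of field names declared on `EvalConfig`:

```python
# Assumed definition, by analogy with TRAIN_CONFIG_KEYS in this diff.
from dataclasses import fields

from llmfoundry.utils.config_utils import EvalConfig

EVAL_CONFIG_KEYS = {field.name for field in fields(EvalConfig)}
```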
tests/a_scripts/eval/test_eval_inputs.py (4 additions, 11 deletions)
```diff
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 import os
-import warnings
 
 import omegaconf
 import pytest
```
```diff
@@ -42,12 +41,13 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
             omegaconf.errors.InterpolationKeyError,
             omegaconf.errors.MissingMandatoryValue,
             TypeError,
+            ValueError,
         )):
             cfg[p + '-mispelled'] = cfg.pop(p)
             main(cfg)
         cfg[p] = cfg.pop(p + '-mispelled')
 
-    def test_optional_mispelled_params_raise_warning(
+    def test_optional_mispelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
```
```diff
@@ -67,15 +67,8 @@ def test_optional_mispelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-mispelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
 
```
tests/a_scripts/train/test_train_inputs.py (6 additions, 12 deletions)
```diff
@@ -3,7 +3,6 @@
 import copy
 import json
 import os
-import warnings
 
 import omegaconf
 import pytest
```
```diff
@@ -63,7 +62,9 @@ def cfg(self, foundry_dir: str) -> DictConfig:
     def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None:
         """Check that mandatory misspelled inputs fail to train."""
         cfg.trai_loader = cfg.pop('train_loader')
-        with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)):
+        with pytest.raises(
+            (omegaconf.errors.MissingMandatoryValue, TypeError, ValueError),
+        ):
             main(cfg)
 
     def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
```
```diff
@@ -89,7 +90,7 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None:
             main(cfg)
             cfg[param] = orig_param
 
-    def test_optional_misspelled_params_raise_warning(
+    def test_optional_misspelled_params_raise_error(
         self,
         cfg: DictConfig,
     ) -> None:
```
```diff
@@ -113,15 +114,8 @@ def test_optional_misspelled_params_raise_warning(
             orig_value = cfg.pop(param, None)
             updated_param = param + '-misspelling'
             cfg[updated_param] = orig_value
-            with warnings.catch_warnings(record=True) as warning_list:
-                try:
-                    main(cfg)
-                except:
-                    pass
-                assert any(
-                    f'Unused parameter {updated_param} found in cfg.' in
-                    str(warning.message) for warning in warning_list
-                )
+            with pytest.raises(ValueError):
+                main(cfg)
             # restore configs.
             cfg = copy.deepcopy(old_cfg)
 
```