def move_variables_to_section(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Move root-level ``variables.<name>`` keys into the ``variables`` section.

    Keys at the root of the config whose name starts with ``variables.`` are
    removed from the root and stored as ``cfg['variables'][<name>]``, where
    ``<name>`` is everything after the ``variables.`` prefix (further dots in
    the name are kept verbatim). A ``variables`` section is created if the
    config does not already have one.

    This is meant to run before interpolation is resolved (e.g. before
    ``om.to_container(cfg, resolve=True)``), because the variables section is
    required for interpolation.

    Args:
        cfg: The config mapping. It is modified in place.

    Returns:
        The same ``cfg`` object, after the keys have been moved.
    """
    prefix = 'variables.'
    if 'variables' not in cfg:
        cfg['variables'] = {}
    # Snapshot the keys: entries are popped from ``cfg`` while looping.
    for key in list(cfg.keys()):
        # Only a leading ``variables.`` marks a variable. A substring test
        # would also capture unrelated keys such as ``my_variables.foo`` or
        # ``model.variables.bar`` and file them under a mangled name.
        if isinstance(key, str) and key.startswith(prefix):
            cfg['variables'][key[len(prefix):]] = cfg.pop(key)
    return cfg