From 7de912e09756521f3561566b0a4d18a4c0a4988e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 Feb 2024 14:24:28 -0500 Subject: [PATCH] hotfix for capabilities loading (#1331) --- src/axolotl/cli/__init__.py | 14 +++++++------- src/axolotl/utils/config/__init__.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index a156342474..abca478e4d 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,7 +30,6 @@ from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.config import ( - GPUCapabilities, normalize_cfg_datasets, normalize_config, validate_config, @@ -350,14 +349,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): except: # pylint: disable=bare-except # noqa: E722 gpu_version = None - capabilities = GPUCapabilities( - bf16=is_torch_bf16_gpu_available(), - n_gpu=os.environ.get("WORLD_SIZE", 1), - compute_capability=gpu_version, + cfg = validate_config( + cfg, + capabilities={ + "bf16": is_torch_bf16_gpu_available(), + "n_gpu": os.environ.get("WORLD_SIZE", 1), + "compute_capability": gpu_version, + }, ) - cfg = validate_config(cfg, capabilities=capabilities) - prepare_optim_env(cfg) normalize_config(cfg) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index b21db31760..6635ff8e2a 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -13,7 +13,6 @@ AxolotlConfigWCapabilities, AxolotlInputConfig, ) -from axolotl.utils.config.models.internals import GPUCapabilities from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config @@ -197,7 +196,7 @@ def normalize_cfg_datasets(cfg): cfg.datasets[idx].conversation = "chatml" -def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None): +def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): if capabilities: return DictDefault( dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))