From e050495948ecfa52fc3bec98c9337ef82bd0fbab Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Mon, 4 Sep 2023 00:19:03 -0400
Subject: [PATCH] move is_llama_derived_model into normalize_config (#524)

---
 scripts/finetune.py         | 11 +----------
 src/axolotl/utils/config.py | 11 +++++++++++
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/scripts/finetune.py b/scripts/finetune.py
index 0a5f318639..b998edc798 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -24,7 +24,7 @@
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model_config, load_tokenizer
+from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.wandb import setup_wandb_env_vars
 
@@ -216,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
         else:
             cfg[k] = kwargs[k]
 
-    model_config = load_model_config(cfg)
-
-    # figure out if the model is llama
-    cfg.is_llama_derived_model = (
-        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
-        or cfg.is_llama_derived_model
-        or "llama" in cfg.base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-    )
     validate_config(cfg)
 
     normalize_config(cfg)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index abb3154d21..93a23f7738 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -6,6 +6,7 @@
 import torch
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.models import load_model_config
 
 LOG = logging.getLogger("axolotl")
 
@@ -69,6 +70,16 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32
 
+    model_config = load_model_config(cfg)
+
+    # figure out if the model is llama
+    cfg.is_llama_derived_model = (
+        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        or cfg.is_llama_derived_model
+        or "llama" in cfg.base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
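
Note (not part of the patch itself): a minimal, self-contained sketch of the check that now runs inside normalize_config(cfg). The helper name detect_llama_derived_model and the SimpleNamespace stand-in for the loaded model config are illustrative only, and the real code reads cfg as axolotl's DictDefault, where missing keys evaluate as falsy; a plain dict with .get() is used here to mimic that behavior.

from types import SimpleNamespace


def detect_llama_derived_model(cfg, model_config):
    """Mirror of the is_llama_derived_model check moved into normalize_config."""
    return bool(
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
        or cfg.get("is_llama_derived_model")
        or "llama" in cfg.get("base_model", "")
        or (cfg.get("model_type") and "llama" in cfg["model_type"].lower())
    )


# Example: either the HF config's model_type or a "llama" substring in the
# base_model name is enough to mark the model as llama-derived.
cfg = {"base_model": "openlm-research/open_llama_3b", "model_type": None}
model_config = SimpleNamespace(model_type="llama")
print(detect_llama_derived_model(cfg, model_config))  # True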