Skip to content

Commit

Permalink
move the updating of model config to the load_model_config function
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Nov 15, 2023
1 parent 13483fe commit 539da6b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
return AutoConfig.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)

return model_config


def load_tokenizer(cfg):
Expand Down Expand Up @@ -232,10 +237,6 @@ def load_model(
):
model_kwargs["use_flash_attention_2"] = True

if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)

try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
Expand Down

0 comments on commit 539da6b

Please sign in to comment.