Skip to content

Commit

Permalink
allow overriding of model_config parameters from the YML (#853)
Browse files Browse the repository at this point in the history
* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config
  • Loading branch information
winglian authored Nov 16, 2023
1 parent b3a61e8 commit 1bc1186
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 40 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,14 @@ is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:

# optional overrides to the base model configuration
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float


# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
Expand Down Expand Up @@ -756,10 +764,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only
xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float

# Resume from a specific checkpoint dir
resume_from_checkpoint:
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,9 @@ def validate_config(cfg):
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)

if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
57 changes: 21 additions & 36 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
Expand All @@ -32,9 +31,14 @@
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True
return AutoConfig.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)

return model_config


def load_tokenizer(cfg):
Expand All @@ -51,7 +55,7 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
Expand Down Expand Up @@ -110,7 +114,6 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
model_config = load_model_config(cfg)

Expand Down Expand Up @@ -238,16 +241,9 @@ def load_model(
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM

config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
Expand Down Expand Up @@ -305,66 +301,55 @@ def load_model(
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(config, "max_seq_len")
and config.max_seq_len
and cfg.sequence_len > config.max_seq_len
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
config.max_seq_len = cfg.sequence_len
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(config, "max_sequence_length")
and config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
config.max_sequence_length = cfg.sequence_len
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
raise err

embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
Expand Down

0 comments on commit 1bc1186

Please sign in to comment.