diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index b3f0b361d8..2de2f6b325 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -116,6 +116,13 @@ class InvalidConfigAccessError(KeyError): class PartialLlamaConfig(LlamaConfig): + """Holds the rope config for Llama models and throws + + an `InvalidConfigAccessError` if any other config elements + are read. This class is necessary because the + `LlamaRotaryEmbedding` class takes a full `LlamaConfig` now + instead of the old keyword arguments. + """ def __getattribute__(self, key: str): if key not in _ALLOWED_LLAMA_CONFIG_KEYS: @@ -129,9 +136,6 @@ def __getitem__(self, key: str): return super().__getitem__(key) - def _get_generation_defaults(self): - return {} - def gen_rotary_embedding( rope_impl: str,