diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index a7ce93d236..b3f4636fb6 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -94,30 +94,41 @@ class InvalidConfigAccessError(KeyError):
     pass
 
 
+_ALLOWED_LLAMA_CONFIG_KEYS = {
+    'rope_scaling',
+    'rope_theta',
+    'max_position_embeddings',
+    'hidden_size',
+    'num_attention_heads',
+    '_get_generation_defaults',
+    'label2id',
+    'id2label',
+    'torch_dtype',
+    'problem_type',
+    '__class__',
+    'partial_rotary_factor',
+}
+
+
 class PartialLlamaConfig(LlamaConfig):
-    _ALLOWED_KEYS = {
-        'rope_scaling',
-        'rope_theta',
-        'max_position_embeddings',
-        'hidden_size',
-        'num_attention_heads',
-    }
 
     def __getattribute__(self, key: str):
-        if key not in self._ALLOWED_KEYS:
+        if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
             raise InvalidConfigAccessError(key)
 
         return super().__getattribute__(key)
 
     def __getitem__(self, key: str):
-        if key not in self._ALLOWED_KEYS:
+        if key not in _ALLOWED_LLAMA_CONFIG_KEYS:
             raise InvalidConfigAccessError(key)
 
         return super().__getitem__(key)
 
+    def _get_generation_defaults(self):
+        return {}
+
 
 def gen_rotary_embedding(
-    rope_head_dim: int,
     rope_impl: str,
     rope_theta: int,
     rope_dail_config: dict,
@@ -126,6 +137,7 @@ def gen_rotary_embedding(
     d_model: int,
     n_heads: int,
 ):
+    rope_head_dim = d_model // n_heads
     if rope_impl == 'dail':
         return DAILRotaryEmbedding(
             dim=rope_head_dim,
@@ -165,9 +177,11 @@ def gen_rotary_embedding(
                 'cpu',  # FSDP does not materialize modules with meta buffers, hence device is set to cpu
             )
         elif rope_hf_config['type'] == 'llama3':
+            llama_rope_config = {**rope_hf_config}
+            llama_rope_config['rope_type'] = rope_hf_config.pop('type')
             return LlamaRotaryEmbedding(
                 config=PartialLlamaConfig(
-                    rope_scaling=rope_hf_config,
+                    rope_scaling=llama_rope_config,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_seq_len,
                     hidden_size=d_model,
@@ -439,7 +453,6 @@ def __init__(self, config: MPTConfig):
         if self.rope:
             self.rope_impl = config.attn_config['rope_impl']
             self.rotary_embedding = gen_rotary_embedding(
-                rope_head_dim=config.d_model // config.n_heads,
                 rope_impl=self.rope_impl,
                 rope_theta=config.attn_config['rope_theta'],
                 rope_dail_config=config.attn_config['rope_dail_config'],
diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py
index 82b221ae95..4bfdfb84dc 100644
--- a/tests/models/layers/test_flash_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -251,7 +251,6 @@ def gen_bias(attn_impl: str):
     rotary_emb_w_meta_info = None
     if rope:
         rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg.d_model // cfg.n_heads,
             rope_impl=pos_emb_config['rope_impl'],
             rope_theta=pos_emb_config['rope_theta'],
             rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
@@ -666,7 +665,6 @@ def gen_bias(attn_impl: str):
     rotary_emb_w_meta_info = None
     if rope:
         rotary_embedding = gen_rotary_embedding(
-            rope_head_dim=cfg['d_model'] // cfg['n_heads'],
             rope_impl=pos_emb_config['rope_impl'],
             rope_theta=pos_emb_config['rope_theta'],
             rope_dail_config=pos_emb_config.get('rope_dail_config', {}),
diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 00592caa77..34fb23f670 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -77,7 +77,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     dail_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=dail_rope_config['rope_impl'],
         rope_theta=dail_rope_config['rope_theta'],
         rope_dail_config=dail_rope_config['rope_dail_config'],
@@ -94,7 +93,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     }
 
     hf_rope = gen_rotary_embedding(
-        rope_head_dim=cfg.d_model // cfg.n_heads,
         rope_impl=hf_rope_config['rope_impl'],
         rope_theta=hf_rope_config['rope_theta'],
         rope_dail_config={},
diff --git a/tests/models/test_rope_scaling.py b/tests/models/test_rope_scaling.py
new file mode 100644
index 0000000000..076f0396b5
--- /dev/null
+++ b/tests/models/test_rope_scaling.py
@@ -0,0 +1,36 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
+from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
+
+rope_config = {
+    'rope_theta': 500000.0,
+    'rope_impl': 'hf',
+    'rope_hf_config': {
+        'factor': 8.0,
+        'low_freq_factor': 1.0,
+        'high_freq_factor': 4.0,
+        'original_max_position_embeddings': 8192,
+        'type': 'llama3',
+    },
+}
+
+rope_dail_config = {}
+
+
+def test_rope_scaling():
+    d_model = 128
+    n_heads = 32
+    max_seq_len = 131_000
+
+    embedding = gen_rotary_embedding(
+        d_model=d_model,
+        n_heads=n_heads,
+        rope_dail_config=rope_dail_config,
+        max_seq_len=max_seq_len,
+        **rope_config,
+    )
+
+    assert isinstance(embedding, LlamaRotaryEmbedding)
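
Usage sketch (illustrative only): the new test exercises gen_rotary_embedding directly; the snippet below shows roughly how the same llama3 scaling would be requested through an MPT attention config. It assumes the attn_config keys visible in the hunks above ('rope_impl', 'rope_theta', 'rope_dail_config', 'rope_hf_config') plus an on/off 'rope' flag; values other than the rope_hf_config shape taken from the new test are placeholders, not defaults established by this diff. Callers no longer pass rope_head_dim; gen_rotary_embedding derives it as d_model // n_heads.

# Hypothetical attn_config fragment enabling llama3 rope scaling.
# Key names mirror the hunks and the new test above; the full set of keys
# MPTConfig accepts is defined elsewhere in llm-foundry and is assumed here.
attn_config = {
    'rope': True,            # assumed on/off flag backing `self.rope` in MPTModel
    'rope_impl': 'hf',       # route to the Hugging Face rope path
    'rope_theta': 500000.0,
    'rope_dail_config': {},  # unused on the 'hf' path
    'rope_hf_config': {
        'type': 'llama3',    # remapped to 'rope_type' before building PartialLlamaConfig
        'factor': 8.0,
        'low_freq_factor': 1.0,
        'high_freq_factor': 4.0,
        'original_max_position_embeddings': 8192,
    },
}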