diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 7dfaf8562b..0bcd587084 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -49,10 +49,10 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaConfig -from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaRotaryEmbedding, +) from llmfoundry.layers_registry import norms, param_init_fns from llmfoundry.models.layers.attention import ( @@ -166,7 +166,7 @@ def gen_rotary_embedding( num_attention_heads=n_heads, ) if rope_hf_config['type'] == 'no_scaling': - return HFRotaryEmbeddingFoundry(config=partial_llama_config) + return LlamaRotaryEmbeddingFoundry(config=partial_llama_config) elif rope_hf_config['type'] in {'llama3', 'linear', 'dynamic'}: return LlamaRotaryEmbedding(config=partial_llama_config) raise ValueError('rope_impl needs to be either dail or hf') @@ -341,7 +341,7 @@ def apply_sequence_id( return attn_bias -class HFRotaryEmbeddingFoundry(HFRotaryEmbedding): +class LlamaRotaryEmbeddingFoundry(LlamaRotaryEmbedding): @torch.no_grad() def forward( diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 45378e42bd..ed40e7a88a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -35,8 +35,7 @@ ) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor -from transformers.models.llama.modeling_llama import \ - LlamaRotaryEmbedding as HFRotaryEmbedding +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding from llmfoundry import ComposerHFCausalLM from llmfoundry.layers_registry import norms @@ -48,7 +47,7 @@ ) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM, MPTModel -from llmfoundry.models.mpt.modeling_mpt import HFRotaryEmbeddingFoundry +from llmfoundry.models.mpt.modeling_mpt import LlamaRotaryEmbeddingFoundry from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model from llmfoundry.utils.config_utils import to_dict_container @@ -2924,7 +2923,7 @@ def test_hf_rotary_child_class_builds(): list(range(max_seq_len)), ] * bsz) - rot_emb_mp = HFRotaryEmbeddingFoundry( + rot_emb_mp = LlamaRotaryEmbeddingFoundry( rope_head_dim, max_seq_len, rope_theta, @@ -2932,7 +2931,7 @@ def test_hf_rotary_child_class_builds(): ) cos_mp, sin_mp = rot_emb_mp(value, position_ids) - rot_emb = HFRotaryEmbedding( + rot_emb = LlamaRotaryEmbedding( rope_head_dim, max_seq_len, rope_theta,