diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 35b864c2d5..7dfaf8562b 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -155,7 +155,7 @@ def gen_rotary_embedding( ) elif rope_impl == 'hf': llama_rope_config = {**rope_hf_config} - llama_rope_config['rope_type'] = rope_hf_config.pop('type') + llama_rope_config['rope_type'] = llama_rope_config.pop('type') if llama_rope_config['rope_type'] == 'no_scaling': llama_rope_config['rope_type'] = 'default' partial_llama_config = PartialLlamaConfig(