diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 93d8cbef74..ff33990d7a 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -826,7 +826,8 @@ def set_output_embeddings(
         self.transformer.set_input_embeddings(new_embeddings)

     def tie_weights(self) -> None:
-        self.lm_head = None
+        if getattr(self.config, 'tie_word_embeddings', True):
+            self.lm_head = None

     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py
index b47f267c55..d007850b68 100644
--- a/tests/models/hf/test_hf_config.py
+++ b/tests/models/hf/test_hf_config.py
@@ -49,6 +49,31 @@ def test_remote_code_false_mpt(
         tokenizer)


+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_tie_weights(tie_word_embeddings: bool):
+    # Test that the tie_weights function sets lm_head correctly
+    hf_config = MPTConfig(init_device='cpu',
+                          d_model=128,
+                          n_heads=4,
+                          n_layers=2,
+                          expansion_ratio=2,
+                          max_seq_len=2048,
+                          attn_config={
+                              'attn_impl': 'torch',
+                          },
+                          no_bias=True,
+                          tie_word_embeddings=tie_word_embeddings)
+
+    mpt = MPTForCausalLM(hf_config)
+
+    assert mpt.config.tie_word_embeddings == tie_word_embeddings
+    mpt.tie_weights()
+    if tie_word_embeddings:
+        assert mpt.lm_head is None
+    else:
+        assert mpt.lm_head is not None
+
+
 @pytest.mark.parametrize('model_cfg_overrides', [
     {
         'max_seq_len': 1024
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index a3a7fb3814..ceab83ac74 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1222,6 +1222,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
         'transformer.blocks.0': 0,
         'transformer.blocks.1': 1 if world_size == 2 else 0,
         'transformer.norm_f': 1 if world_size == 2 else 0,
+        'lm_head': 1 if world_size == 2 else 0,
     }

     pipe = pipeline(
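
For context, here is a minimal sketch (not part of the diff) of how the guarded `tie_weights` behaves when embeddings are untied. Hugging Face's `PreTrainedModel` machinery invokes `tie_weights()` (e.g. via `post_init` and `from_pretrained`), so the previous unconditional `self.lm_head = None` would have discarded the separate output head of a model configured with `tie_word_embeddings=False`. The import path and constructor arguments below mirror the new test above and are assumptions about the repo layout, not part of this change.

```python
# Sketch only: import path and config arguments assumed from the test added
# in this diff.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    init_device='cpu',
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=2048,
    attn_config={'attn_impl': 'torch'},
    no_bias=True,
    tie_word_embeddings=False,  # request a separate (untied) output head
)

model = MPTForCausalLM(config)

# tie_weights() is also called by Hugging Face machinery (post_init,
# from_pretrained); with the new guard it leaves the untied head intact
# instead of resetting it to None.
model.tie_weights()
assert model.lm_head is not None
```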