diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 2f0dcb890d..a2a96246e7 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -183,16 +183,6 @@ def __init__(self, config: MPTConfig):
         ])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
 
-        self.lm_head = None
-        if config.tie_word_embeddings is False:
-            self.lm_head = nn.Linear(
-                config.d_model,
-                config.vocab_size,
-                bias=False,
-                device=config.init_device,
-            )
-            self.lm_head._fsdp_wrap = True
-
         self.rope = config.attn_config['rope']
         self.rope_impl = None
         if self.rope:
@@ -248,23 +238,6 @@ def set_input_embeddings(
             self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
         self.wte = value
 
-    def get_output_embeddings(
-            self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
-        return self.lm_head or self.wte
-
-    def set_output_embeddings(
-            self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
-                                        nn.Linear]) -> None:
-        if self.lm_head is not None:
-            self.lm_head = new_embeddings
-        else:
-            self.wte = new_embeddings
-
-    def tie_weights(self) -> None:
-        if self.lm_head is not None:
-            del self.lm_head
-            self.lm_head = None
-
     @torch.no_grad()
     def _attn_bias(
         self,
@@ -606,6 +579,16 @@ def __init__(self, config: MPTConfig):
 
         self.transformer: MPTModel = MPTModel(config)
 
+        self.lm_head = None
+        if config.tie_word_embeddings is False:
+            self.lm_head = nn.Linear(
+                config.d_model,
+                config.vocab_size,
+                bias=False,
+                device=config.init_device,
+            )
+            self.lm_head._fsdp_wrap = True
+
         for child in self.transformer.children():
             if isinstance(child, torch.nn.ModuleList):
                 continue
@@ -635,12 +618,21 @@ def set_input_embeddings(
 
     def get_output_embeddings(
             self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
-        return self.transformer.get_output_embeddings()
+        return self.lm_head or self.transformer.get_input_embeddings()
 
     def set_output_embeddings(
             self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
                                         nn.Linear]) -> None:
-        self.transformer.set_output_embeddings(new_embeddings)
+        if self.lm_head is not None:
+            self.lm_head = new_embeddings
+        else:
+            assert isinstance(new_embeddings, (SharedEmbedding, nn.Embedding))
+            self.transformer.set_input_embeddings(new_embeddings)
+
+    def tie_weights(self) -> None:
+        if self.lm_head is not None:
+            del self.lm_head
+            self.lm_head = None
 
     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
@@ -684,8 +676,8 @@ def forward(
             use_cache=use_cache,
         )
 
-        if self.transformer.lm_head is not None:
-            logits = self.transformer.lm_head(outputs.last_hidden_state)
+        if self.lm_head is not None:
+            logits = self.lm_head(outputs.last_hidden_state)
         else:
             # move outputs to same device as weights for token embedding
             # needed to support HF `device_map`
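
The net effect of this diff is that the optional untied output projection now lives on `MPTForCausalLM` rather than on the inner `MPTModel`, so `get_output_embeddings`, `set_output_embeddings`, and `tie_weights` resolve at the wrapper level. Below is a minimal, self-contained sketch of that ownership pattern; `ToyDecoder` and `ToyForCausalLM` are illustrative stand-ins, not llm-foundry classes, and the forward pass is deliberately trivial.

```python
# Sketch of the pattern in this diff: the CausalLM wrapper (not the inner
# decoder) owns the optional untied `lm_head`, so the HF-style accessors and
# `tie_weights` operate on the wrapper. Toy classes only, for illustration.
from typing import Optional, Union

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Stand-in for MPTModel: owns only the input embedding `wte`."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.wte

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.wte = value


class ToyForCausalLM(nn.Module):
    """Stand-in for MPTForCausalLM: owns `lm_head` only when untied."""

    def __init__(self, d_model: int, vocab_size: int,
                 tie_word_embeddings: bool):
        super().__init__()
        self.transformer = ToyDecoder(d_model, vocab_size)
        self.lm_head: Optional[nn.Linear] = None
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def get_output_embeddings(self) -> Union[nn.Embedding, nn.Linear]:
        # Untied: the separate projection; tied: fall back to the input embedding.
        return self.lm_head or self.transformer.get_input_embeddings()

    def tie_weights(self) -> None:
        # Tying drops the separate head; forward() then reuses `wte`.
        if self.lm_head is not None:
            del self.lm_head
            self.lm_head = None

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.transformer.wte(input_ids)  # toy "decoder": embedding only
        if self.lm_head is not None:
            return self.lm_head(hidden)
        return hidden @ self.transformer.wte.weight.t()  # tied projection


untied = ToyForCausalLM(d_model=8, vocab_size=32, tie_word_embeddings=False)
assert isinstance(untied.get_output_embeddings(), nn.Linear)

untied.tie_weights()
assert untied.get_output_embeddings() is untied.transformer.wte
```

With the head on the wrapper, tying or swapping output embeddings no longer has to reach into the decoder, which mirrors how the diff rewires `MPTForCausalLM.get_output_embeddings` to fall back to `self.transformer.get_input_embeddings()` when `lm_head` is `None`.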