pr comments
vchiley committed Nov 13, 2023
1 parent 1ea368c · commit 1b073f4
Showing 1 changed file with 10 additions and 4 deletions.
llmfoundry/models/mpt/modeling_mpt.py (10 additions & 4 deletions)

@@ -628,13 +628,19 @@ def set_output_embeddings(
         if self.lm_head is not None:
             self.lm_head = new_embeddings
         else:
-            assert isinstance(new_embeddings, (SharedEmbedding, nn.Embedding))
+            if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)):
+                raise ValueError(
+                    'new_embeddings must be an instance of SharedEmbedding ' +
+                    f'or nn.Embedding, but got {type(new_embeddings)}.')
+            warnings.warn(
+                'Using `set_output_embeddings` to set the embedding layer of ' +
+                'MPTForCausalLM with tied weights. Given weights are tied, ' +
+                'using `set_input_embeddings` is recommended over using ' +
+                '`set_output_embeddings`.')
             self.transformer.set_input_embeddings(new_embeddings)
 
     def tie_weights(self) -> None:
-        if self.lm_head is not None:
-            del self.lm_head
-            self.lm_head = None
+        self.lm_head = None
 
     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
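For context, a minimal usage sketch of the behavior this change introduces, assuming `model` is an `MPTForCausalLM` instance with tied weights (`model.lm_head is None`); the vocab size and embedding dimension below are illustrative, not taken from the commit:

import warnings

import torch.nn as nn

# Assumption: `model` is an MPTForCausalLM whose input and output
# embeddings are tied, i.e. model.lm_head is None.

# A non-embedding module now raises ValueError instead of tripping an
# assert (asserts are stripped when Python runs with -O).
try:
    model.set_output_embeddings(nn.Linear(768, 50368))
except ValueError as err:
    print(err)  # 'new_embeddings must be an instance of SharedEmbedding ...'

# A valid nn.Embedding is accepted, but a UserWarning recommends
# `set_input_embeddings`, since tied weights make the two equivalent;
# the new table is routed to the input side.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    model.set_output_embeddings(nn.Embedding(50368, 768))
print(caught[0].message)

# tie_weights now simply drops any separate head, so the output
# embeddings resolve back to the shared input table.
model.tie_weights()
assert model.lm_head is None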
