refactor
vchiley committed Nov 10, 2023
1 parent 62295e8 commit 6c96bd1
Showing 1 changed file with 23 additions and 31 deletions.
llmfoundry/models/mpt/modeling_mpt.py (54 changes: 23 additions & 31 deletions)
@@ -183,16 +183,6 @@ def __init__(self, config: MPTConfig):
         ])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
 
-        self.lm_head = None
-        if config.tie_word_embeddings is False:
-            self.lm_head = nn.Linear(
-                config.d_model,
-                config.vocab_size,
-                bias=False,
-                device=config.init_device,
-            )
-            self.lm_head._fsdp_wrap = True
-
         self.rope = config.attn_config['rope']
         self.rope_impl = None
         if self.rope:
@@ -248,23 +238,6 @@ def set_input_embeddings(
             self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
         self.wte = value
 
-    def get_output_embeddings(
-            self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
-        return self.lm_head or self.wte
-
-    def set_output_embeddings(
-            self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
-                                        nn.Linear]) -> None:
-        if self.lm_head is not None:
-            self.lm_head = new_embeddings
-        else:
-            self.wte = new_embeddings
-
-    def tie_weights(self) -> None:
-        if self.lm_head is not None:
-            del self.lm_head
-            self.lm_head = None
-
     @torch.no_grad()
     def _attn_bias(
         self,
@@ -606,6 +579,16 @@ def __init__(self, config: MPTConfig):
 
         self.transformer: MPTModel = MPTModel(config)
 
+        self.lm_head = None
+        if config.tie_word_embeddings is False:
+            self.lm_head = nn.Linear(
+                config.d_model,
+                config.vocab_size,
+                bias=False,
+                device=config.init_device,
+            )
+            self.lm_head._fsdp_wrap = True
+
         for child in self.transformer.children():
             if isinstance(child, torch.nn.ModuleList):
                 continue
@@ -635,12 +618,21 @@ def set_input_embeddings(
 
     def get_output_embeddings(
             self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
-        return self.transformer.get_output_embeddings()
+        return self.lm_head or self.transformer.get_input_embeddings()
 
     def set_output_embeddings(
             self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
                                         nn.Linear]) -> None:
-        self.transformer.set_output_embeddings(new_embeddings)
+        if self.lm_head is not None:
+            self.lm_head = new_embeddings
+        else:
+            assert isinstance(new_embeddings, (SharedEmbedding, nn.Embedding))
+            self.transformer.set_input_embeddings(new_embeddings)
+
+    def tie_weights(self) -> None:
+        if self.lm_head is not None:
+            del self.lm_head
+            self.lm_head = None
 
     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
@@ -684,8 +676,8 @@ def forward(
             use_cache=use_cache,
         )
 
-        if self.transformer.lm_head is not None:
-            logits = self.transformer.lm_head(outputs.last_hidden_state)
+        if self.lm_head is not None:
+            logits = self.lm_head(outputs.last_hidden_state)
         else:
             # move outputs to same device as weights for token embedding
             # needed to support HF `device_map`
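Taken together, the hunks above move lm_head out of MPTModel and into MPTForCausalLM: the causal-LM wrapper now owns the untied output projection, get_output_embeddings falls back to the transformer's input embeddings when weights are tied, and tie_weights simply drops the head. Below is a minimal, self-contained sketch of that tied/untied pattern; ToyCausalLM and its constructor arguments are illustrative stand-ins, not the real llmfoundry API.

import torch
import torch.nn as nn


class ToyCausalLM(nn.Module):
    """Illustrative sketch of the tied/untied output-embedding pattern above."""

    def __init__(self, d_model: int, vocab_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)  # input embeddings
        self.lm_head = None
        if not tie_word_embeddings:
            # Untied: the LM wrapper owns a dedicated output projection.
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def get_output_embeddings(self) -> nn.Module:
        # Mirrors `self.lm_head or self.transformer.get_input_embeddings()`:
        # fall back to the input embeddings when the weights are tied.
        return self.lm_head or self.wte

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        if self.lm_head is not None:
            return self.lm_head(hidden)
        # Tied case: reuse the embedding matrix as the output projection.
        return hidden @ self.wte.weight.t()


# Usage: with tying enabled, logits come from the shared embedding matrix.
model = ToyCausalLM(d_model=8, vocab_size=32, tie_word_embeddings=True)
logits = model(torch.randn(2, 4, 8))  # shape: (2, 4, 32)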
