
Commit

updt pr comment
vchiley committed Nov 11, 2023
1 parent fea8f95 commit fec93fb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -580,7 +580,7 @@ def __init__(self, config: MPTConfig):
         self.transformer: MPTModel = MPTModel(config)
 
         self.lm_head = None
-        if config.tie_word_embeddings is False:
+        if not config.tie_word_embeddings:
             self.lm_head = nn.Linear(
                 config.d_model,
                 config.vocab_size,
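
The behavioral difference between the two checks is worth noting: config.tie_word_embeddings is False matches only a literal False, while not config.tie_word_embeddings also treats None, 0, and other falsy values as untied. A minimal standalone illustration (not taken from the repository) of where the two forms diverge:

# Compare the old check (`is False`) with the new one (`not`).
# A flag loaded from JSON/YAML may end up as None rather than False;
# only the `not` form treats that case as untied.
for tie_word_embeddings in (True, False, None):
    old_check = tie_word_embeddings is False  # old: identity with the False singleton
    new_check = not tie_word_embeddings       # new: any falsy value
    print(f"{tie_word_embeddings!r:>5}: old={old_check}, new={new_check}")

# Output:
#  True: old=False, new=False
# False: old=True, new=True
#  None: old=False, new=True   <- the two checks disagree here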
@@ -882,7 +882,7 @@ def flops_per_batch(self, batch: Mapping) -> int:
 
         bs, msl = batch['input_ids'].shape[0:2]
         params = self.n_active_params
-        if self.model.transformer.config.tie_word_embeddings is False:
+        if not self.model.transformer.config.tie_word_embeddings:
             # embedding layers are lookup tables, therefore are not counted in the FLOP computation
             params -= self.model.transformer.wte.weight.numel()
         params_flops_per_token = 2 * params
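
For context, params_flops_per_token = 2 * params is the usual approximation that each dense parameter contributes roughly one multiply and one add per token in the forward pass; an embedding lookup performs no multiply-accumulates, which is why the wte table is subtracted when the embeddings are untied (when tied, the same weight is reused as the output projection matmul and stays in the count). A rough, self-contained sketch with made-up sizes, not the repository's exact flops_per_batch:

# Illustrative FLOP accounting with hypothetical sizes (not MPT defaults).
d_model, vocab_size = 4096, 50432
n_active_params = 7_000_000_000          # assumed total active parameter count
wte_params = vocab_size * d_model        # input embedding table

tie_word_embeddings = False
params = n_active_params
if not tie_word_embeddings:
    # Untied: the input embedding is a pure lookup table and does no matmul work;
    # the separate lm_head projection remains in the count.
    params -= wte_params

params_flops_per_token = 2 * params      # ~2 FLOPs per parameter per token (forward pass)
print(f"{params_flops_per_token:.3e} forward FLOPs per token")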
