diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 83da422dff..3c61581cb8 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -580,7 +580,7 @@ def __init__(self, config: MPTConfig):
         self.transformer: MPTModel = MPTModel(config)
         self.lm_head = None
-        if config.tie_word_embeddings is False:
+        if not config.tie_word_embeddings:
             self.lm_head = nn.Linear(
                 config.d_model,
                 config.vocab_size,
@@ -882,7 +882,7 @@ def flops_per_batch(self, batch: Mapping) -> int:
         bs, msl = batch['input_ids'].shape[0:2]
         params = self.n_active_params
-        if self.model.transformer.config.tie_word_embeddings is False:
+        if not self.model.transformer.config.tie_word_embeddings:
             # embedding layers are lookup tables, therefore are not counted in the FLOP computation
             params -= self.model.transformer.wte.weight.numel()
         params_flops_per_token = 2 * params
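
Note (illustrative only, not part of the diff): switching from "x is False" to "not x" also treats falsy non-False values such as None or 0 as "untied embeddings", whereas the old check only matched the literal False. A minimal standalone sketch of that difference, with hypothetical variable names:

# Minimal sketch: how the old ("is False") and new ("not") checks differ for
# falsy values that are not the literal False. Names are hypothetical and
# purely illustrative; this snippet is not part of the PR.
for tie_word_embeddings in (True, False, None, 0):
    old_check = tie_word_embeddings is False  # old: only the literal False passes
    new_check = not tie_word_embeddings       # new: any falsy value passes
    print(f'{tie_word_embeddings!r}: old={old_check}, new={new_check}')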