From fec93fb2905679fd3b12c9ef794f5a66fc86bed3 Mon Sep 17 00:00:00 2001
From: root
Date: Sat, 11 Nov 2023 00:11:19 +0000
Subject: [PATCH] updt pr comment

---
 llmfoundry/models/mpt/modeling_mpt.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 83da422dff..3c61581cb8 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -580,7 +580,7 @@ def __init__(self, config: MPTConfig):
         self.transformer: MPTModel = MPTModel(config)
 
         self.lm_head = None
-        if config.tie_word_embeddings is False:
+        if not config.tie_word_embeddings:
             self.lm_head = nn.Linear(
                 config.d_model,
                 config.vocab_size,
@@ -882,7 +882,7 @@ def flops_per_batch(self, batch: Mapping) -> int:
 
         bs, msl = batch['input_ids'].shape[0:2]
         params = self.n_active_params
-        if self.model.transformer.config.tie_word_embeddings is False:
+        if not self.model.transformer.config.tie_word_embeddings:
             # embedding layers are lookup tables, therefore are not counted in the FLOP computation
             params -= self.model.transformer.wte.weight.numel()
         params_flops_per_token = 2 * params
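
Reviewer note: the sketch below is a minimal, hypothetical illustration of the
FLOPs bookkeeping the second hunk touches; it is not llmfoundry code.
`SimpleConfig`, its default sizes, and `estimated_params_flops_per_token` are
illustrative assumptions. It shows why the untied-embedding branch subtracts
the `wte` parameters, and why `not config.tie_word_embeddings` is the broader
check: it also treats a None/unset value as untied, which `is False` would not.

    from dataclasses import dataclass

    @dataclass
    class SimpleConfig:  # hypothetical stand-in for MPTConfig
        d_model: int = 2048
        vocab_size: int = 50368
        tie_word_embeddings: bool = True

    def estimated_params_flops_per_token(n_active_params: int, cfg: SimpleConfig) -> int:
        # Hypothetical helper mirroring the patched branch, not the real method.
        params = n_active_params
        if not cfg.tie_word_embeddings:
            # Untied input embeddings are pure table lookups (no matmul), so their
            # weights are excluded; when tied, the same weights double as the output
            # projection (the patch only creates lm_head when untied) and stay counted.
            params -= cfg.d_model * cfg.vocab_size  # == wte.weight.numel()
        # Each remaining matmul parameter costs ~2 FLOPs (multiply + add) per token.
        return 2 * params

    # `not cfg.tie_word_embeddings` is True for False and also for None/unset,
    # whereas `cfg.tie_word_embeddings is False` matches only a literal False.
    assert estimated_params_flops_per_token(10**9, SimpleConfig()) == 2 * 10**9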