Skip to content

Commit

Permalink
updt pr comment
Browse files Browse the repository at this point in the history
  • Loading branch information
vchiley committed Nov 11, 2023
1 parent fea8f95 commit 702fd24
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def __init__(self, config: MPTConfig):
self.transformer: MPTModel = MPTModel(config)

self.lm_head = None
if config.tie_word_embeddings is False:
if not config.tie_word_embeddings:
self.lm_head = nn.Linear(
config.d_model,
config.vocab_size,
Expand Down Expand Up @@ -882,7 +882,7 @@ def flops_per_batch(self, batch: Mapping) -> int:

bs, msl = batch['input_ids'].shape[0:2]
params = self.n_active_params
if self.model.transformer.config.tie_word_embeddings is False:
if not self.model.transformer.config.tie_word_embeddings:
# embedding layers are lookup tables, therefore are not counted in the FLOP computation
params -= self.model.transformer.wte.weight.numel()
params_flops_per_token = 2 * params
Expand Down
6 changes: 1 addition & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,7 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
torch.testing.assert_close(p1, p2)


@pytest.mark.parametrize('tie_word_embeddings', [True, False])
def test_save_from_pretrained(tie_word_embeddings: bool,
tmp_path: pathlib.Path):
def test_save_from_pretrained(tmp_path: pathlib.Path):
# Test that MPT can be used with the HuggingFace
# save_pretrained/from_pretrained api.
hf_config = MPTConfig(
Expand All @@ -971,12 +969,10 @@ def test_save_from_pretrained(tie_word_embeddings: bool,
attn_config={
'attn_impl': 'torch',
},
tie_word_embeddings=tie_word_embeddings,
)
mpt = MPTForCausalLM(hf_config)

mpt.save_pretrained(tmp_path / 'test-save-pretrained')
print(tmp_path / 'test-save-pretrained')
mpt2 = MPTForCausalLM.from_pretrained(tmp_path / 'test-save-pretrained')

check_hf_model_equivalence(mpt, mpt2)
Expand Down

0 comments on commit 702fd24

Please sign in to comment.