From e758d7d1f89d9c7d7b0fc600594ec71922c14f48 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Sat, 7 Oct 2023 03:23:38 +0000
Subject: [PATCH] Ensure that the other ranks do not hit EOS

---
 tests/test_hf_mpt_gen.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py
index 05b7fbad99..be541f4a26 100644
--- a/tests/test_hf_mpt_gen.py
+++ b/tests/test_hf_mpt_gen.py
@@ -102,9 +102,13 @@ def forward(
             inputs_embeds: Optional[torch.FloatTensor] = None,
         ):
             result = super().forward(input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache, inputs_embeds)
-            # Rank 0 should hit EOS immediately.
+            # Modify the logits to select the next token.
             if dist.get_global_rank() == 0:
+                # Rank 0 hits EOS immediately.
                 result.logits[:, :, EOS_TOKEN_ID] = torch.inf
+            else:
+                # Other ranks do not hit EOS.
+                result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
             return result
 
     def mock_from_config(config: MPTConfig, **_):
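
Note (not part of the patch): the mocked forward above pins the EOS logit to +inf on rank 0 and to -inf on every other rank, so under greedy decoding rank 0 emits EOS on its first step while the other ranks never do. This exercises generation where ranks stop at different times. Below is a minimal, self-contained sketch of why the +/-inf logit trick works; the EOS_TOKEN_ID and vocabulary size here are illustrative assumptions, not values taken from the patch.

import torch

EOS_TOKEN_ID = 0  # illustrative; the real test defines its own constant
VOCAB_SIZE = 8    # illustrative vocabulary size

logits = torch.randn(1, 1, VOCAB_SIZE)  # (batch, seq, vocab)

# "Rank 0" behavior: +inf makes EOS the argmax, so greedy decoding stops.
forced = logits.clone()
forced[:, :, EOS_TOKEN_ID] = torch.inf
assert forced[0, -1].argmax().item() == EOS_TOKEN_ID

# "Other ranks" behavior: -inf makes EOS unselectable under argmax.
suppressed = logits.clone()
suppressed[:, :, EOS_TOKEN_ID] = -torch.inf
assert suppressed[0, -1].argmax().item() != EOS_TOKEN_ID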