Skip to content

Commit

Permalink
Ensure that the other ranks do not hit EOS
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 7, 2023
1 parent 655fc75 commit e758d7d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/test_hf_mpt_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
):
result = super().forward(input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache, inputs_embeds)
# Rank 0 should hit EOS immediately.
# Modify the logits to select the next token.
if dist.get_global_rank() == 0:
# Rank 0 hits EOS immediately.
result.logits[:, :, EOS_TOKEN_ID] = torch.inf
else:
# Other ranks do not hit EOS.
result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
return result

def mock_from_config(config: MPTConfig, **_):
Expand Down

0 comments on commit e758d7d

Please sign in to comment.