
Fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking (#940)

* ..

* undoing prev commit

* fixing the gen_attention_mask_in_length function to handle the case when sequence id is -1 due to attention masking

* Update modeling_mpt.py

* ..

---------

Co-authored-by: Daniel King <[email protected]>
ShashankMosaicML and dakinggg authored Feb 5, 2024
1 parent b9d2bfa commit ad126a6
Showing 2 changed files with 9 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -212,6 +212,11 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             raise ValueError(
                 f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).'
             )
+        if attention_mask is not None:
+            # -1 is used to pad the sequence_id where attention mask is False (https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249).
+            # We replace those -1 with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing.
+            # We apply the attention mask again after the one_hot operation.
+            sequence_id = sequence_id.masked_fill(~attention_mask, 0)
         attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
             attention_mask_in_length = attention_mask_in_length.masked_fill(
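
For readers skimming the diff, a minimal standalone sketch (not the repository's code; the tensor values and the safe_sequence_id name are illustrative) of the failure mode and of the fix above, assuming one packed sample whose padded position carries sequence_id == -1:

import torch

sequence_id = torch.tensor([[0, 0, 1, 1, 1, -1]])        # -1 marks a padded token
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]]).bool()

# torch.nn.functional.one_hot rejects negative class values, so clear the -1
# entries before one-hot encoding.
safe_sequence_id = sequence_id.masked_fill(~attention_mask, 0)
one_hot = torch.nn.functional.one_hot(safe_sequence_id)  # shape (1, 6, 2)

# Re-apply the attention mask so the padded position, temporarily relabeled as
# sequence 0, does not inflate that sequence's token count.
one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)
print(one_hot.sum(dim=1))  # tensor([[2, 3]]); without re-masking it would be [[3, 3]]

Clearing the -1 entries first avoids the error one_hot raises for negative indices, and re-masking afterwards keeps the per-sequence token counts correct.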
4 changes: 4 additions & 0 deletions tests/models/layers/test_flash_triton_torch.py
@@ -134,6 +134,10 @@ def test_attn_impl(attn_impl_0: str,
         # zero out the last third of the attention mask
         # to simulate padding
         attention_mask[:, -s // 3:] = 0
+        if sequence_id is not None:
+            sequence_id = sequence_id.masked_fill(
+                ~attention_mask, -1
+            )  # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249

     def gen_bias(attn_impl: str):
         causal = True
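
The test-side convention can be sketched the same way (again illustrative, not the repository's code): pad the tail of the attention mask and mark those positions with -1 in sequence_id, matching the packing.py convention linked above.

import torch

s = 6
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 1]])            # two packed sequences
attention_mask = torch.ones(1, s, dtype=torch.bool)
attention_mask[:, -s // 3:] = 0                             # zero out the last third to simulate padding
sequence_id = sequence_id.masked_fill(~attention_mask, -1)  # padded tokens get sequence_id -1
print(sequence_id)  # tensor([[0, 0, 0, 1, -1, -1]])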
