diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 2177124740..79dc8c7f25 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -212,6 +212,11 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             raise ValueError(
                 f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).'
             )
+        if attention_mask is not None:
+            # -1 is used to pad the sequence_id where the attention mask is False (https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249).
+            # We replace those -1 values with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing.
+            # We apply the attention mask again after the one_hot operation.
+            sequence_id = sequence_id.masked_fill(~attention_mask, 0)
         attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
             attention_mask_in_length = attention_mask_in_length.masked_fill(
diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py
index d409486cc6..4e1efa3f34 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_triton_torch.py
@@ -134,6 +134,10 @@ def test_attn_impl(attn_impl_0: str,
         # zero out the last third of the attention mask
         # to simulate padding
         attention_mask[:, -s // 3:] = 0
+        if sequence_id is not None:
+            sequence_id = sequence_id.masked_fill(
+                ~attention_mask, -1
+            )  # Similar to how we set sequence id for padded tokens: https://github.com/mosaicml/llm-foundry/blob/706ea7dd40ba60a98dea5f37695d143d91c98b6c/llmfoundry/data/packing.py#L249
 
     def gen_bias(attn_impl: str):
         causal = True
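
For context, here is a minimal standalone sketch (not part of the diff) of why the `masked_fill` is needed: `torch.nn.functional.one_hot` rejects negative class indices, so the `-1` entries that `packing.py` writes at padded positions must be mapped to 0 before one-hot encoding, and the padded positions are then masked out again afterwards. The tensor shapes and values below are illustrative only.

```python
import torch

# Illustrative packed batch of length 6 where the last two positions are padding.
# sequence_id uses -1 at padded positions, mirroring llmfoundry/data/packing.py.
sequence_id = torch.tensor([[0, 0, 1, 1, -1, -1]])
attention_mask = torch.tensor([[True, True, True, True, False, False]])

# one_hot on the raw sequence_id would raise a RuntimeError because of the -1
# entries (class values must be non-negative), so padded positions are first
# mapped to sequence 0 ...
safe_sequence_id = sequence_id.masked_fill(~attention_mask, 0)
one_hot = torch.nn.functional.one_hot(safe_sequence_id)  # shape (1, 6, 2)

# ... and then masked out again so padding does not count toward any sequence.
one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)
print(one_hot.sum(dim=1))  # tensor([[2, 2]]): two real tokens per sequence
```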