diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 92a0acc67c..86cd11eab2 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -147,10 +147,16 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             sequence_id is not None) and attn_uses_sequence_id and (attn_impl
                                                                     == 'flash'):
         assert S == sequence_id.shape[-1]
+        attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
-            sequence_id = sequence_id.masked_fill(~attention_mask, S)
-        attention_mask_in_length = torch.nn.functional.one_hot(
-            sequence_id, num_classes=S + 1).sum(dim=1)[:, :-1]
+            attention_mask_in_length = attention_mask_in_length.masked_fill(
+                ~attention_mask.unsqueeze(-1), 0)
+        attention_mask_in_length = attention_mask_in_length.sum(dim=1)
+        attention_mask_in_length = torch.nn.functional.pad(
+            attention_mask_in_length,
+            (0, S - attention_mask_in_length.shape[-1]),
+            mode='constant',
+            value=0)
 
     return attention_mask_in_length
 
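
For reference, the new code path builds the per-sub-sequence length tensor consumed by the flash attention variable-length path: it one-hot encodes `sequence_id`, zeroes out padded positions using `attention_mask`, sums over the token axis to get the number of tokens in each packed sub-sequence, and right-pads the result to width `S`. Below is a minimal standalone sketch of that computation on toy tensors; the shapes and example values are illustrative assumptions, not taken from the repository.

```python
import torch

S = 8  # padded sequence length of the packed batch

# Toy packed batch: row 0 packs three sub-sequences, row 1 packs two.
sequence_id = torch.tensor([
    [0, 0, 0, 1, 1, 2, 2, 2],
    [0, 0, 0, 0, 1, 1, 1, 1],
])
# 1 = real token, 0 = padding; the last two tokens of row 1 are padding.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 0, 0],
]).bool()

# One-hot over sequence ids: (batch, S, num_sub_sequences).
lengths = torch.nn.functional.one_hot(sequence_id)
# Zero out padded positions so they do not count toward any sub-sequence.
lengths = lengths.masked_fill(~attention_mask.unsqueeze(-1), 0)
# Sum over the token axis: (batch, num_sub_sequences) token counts.
lengths = lengths.sum(dim=1)
# Right-pad to width S so the output shape is (batch, S) regardless of how
# many sub-sequences were packed into each row.
lengths = torch.nn.functional.pad(
    lengths, (0, S - lengths.shape[-1]), mode='constant', value=0)

print(lengths)
# tensor([[3, 2, 3, 0, 0, 0, 0, 0],
#         [4, 2, 0, 0, 0, 0, 0, 0]])
```

The output matches what the removed code produced, which one-hot encoded `sequence_id` with `num_classes=S + 1` (sending masked tokens to the extra class `S` and then dropping that column after the sum). The rewritten version only one-hot encodes over the number of packed sub-sequences rather than over `S + 1` classes per token, which presumably lowers peak memory for long contexts.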