
Commit a964aea

ShashankMosaicML committed Nov 26, 2023
1 parent e82c723 commit a964aea
Showing 1 changed file with 9 additions and 3 deletions.
llmfoundry/models/mpt/modeling_mpt.py (12 changes: 9 additions & 3 deletions)
@@ -147,10 +147,16 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
                                                                 == 'flash'):
         assert S == sequence_id.shape[-1]
+        attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
-            sequence_id = sequence_id.masked_fill(~attention_mask, S)
-        attention_mask_in_length = torch.nn.functional.one_hot(
-            sequence_id, num_classes=S + 1).sum(dim=1)[:, :-1]
+            attention_mask_in_length = attention_mask_in_length.masked_fill(
+                ~attention_mask.unsqueeze(-1), 0)
+        attention_mask_in_length = attention_mask_in_length.sum(dim=1)
+        attention_mask_in_length = torch.nn.functional.pad(
+            attention_mask_in_length,
+            (0, S - attention_mask_in_length.shape[-1]),
+            mode='constant',
+            value=0)
 
     return attention_mask_in_length
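
For reference, a minimal sketch (not part of the commit) of what the added lines compute for a toy packed batch. The batch of size 1, the tensor values, and the choice of sequence_id at the padding position are illustrative assumptions; the sequence of operations mirrors the added lines above.

# Toy example: one batch row of length S = 6 holding two packed sequences
# (lengths 3 and 2) followed by one padding token. The padding position is
# given a non-negative sequence_id here so that one_hot accepts it; the
# masked_fill step then removes its contribution.
import torch

S = 6
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]]).bool()

counts = torch.nn.functional.one_hot(sequence_id)                # [1, 6, num_sequences]
counts = counts.masked_fill(~attention_mask.unsqueeze(-1), 0)    # zero out the padding row
lengths = counts.sum(dim=1)                                      # [[3, 2]]
lengths = torch.nn.functional.pad(
    lengths, (0, S - lengths.shape[-1]), mode='constant', value=0)
print(lengths)  # tensor([[3, 2, 0, 0, 0, 0]])

Each row of the result lists the lengths of the packed sequences in that batch row, zero-padded to the sequence length S, which matches the per-row length format that flash attention's variable-length unpadding helper expects.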

