
Commit

Update llmfoundry/models/mpt/modeling_mpt.py
Co-authored-by: Vitaliy Chiley <[email protected]>
ShashankMosaicML and vchiley authored Nov 17, 2023
1 parent 84fa710 commit a560f31
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -144,7 +144,7 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
  # Since Flash Attention expects the masks to have the same shape as the keys, we pad it with zeros.
  key_attention_mask_in_length = torch.nn.functional.pad(key_attention_mask_in_length, (0, sequence_id.shape[-1] - S), value=0)

- return query_attention_mask_in_length,key_attention_mask_in_length
+ return query_attention_mask_in_length, key_attention_mask_in_length

  def apply_sequence_id(attn_bias: torch.Tensor,
                        sequence_id: torch.LongTensor,
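For context, the padding step shown in the diff can be sketched in isolation as follows. This is a minimal, hedged example with made-up sizes (the values of S, sequence_id, and the mask are assumptions, not code taken from modeling_mpt.py):

import torch

# Sketch of the padding referenced in the hunk above (illustrative values only).
# Flash Attention expects the per-key "attention mask in length" tensor to have
# the same length as the keys, so the tensor computed over S positions is
# right-padded with zeros up to the full sequence_id length.
S = 4                                               # assumed current sequence length
sequence_id = torch.zeros(1, 6, dtype=torch.long)   # assumed packed/padded length of 6
key_attention_mask_in_length = torch.tensor([[4, 0, 0, 0]])

key_attention_mask_in_length = torch.nn.functional.pad(
    key_attention_mask_in_length,
    (0, sequence_id.shape[-1] - S),  # pad only the right side of the last dimension
    value=0,
)
print(key_attention_mask_in_length)  # tensor([[4, 0, 0, 0, 0, 0]])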
