diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 7260c89aae..91fb3d2fa2 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -144,7 +144,7 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
         # Since Flash Attention expects the masks to have same shape as the keys, we pad it with zeros.
         key_attention_mask_in_length = torch.nn.functional.pad(key_attention_mask_in_length, (0, sequence_id.shape[-1] - S), value=0)
-    return query_attention_mask_in_length,key_attention_mask_in_length
+    return query_attention_mask_in_length, key_attention_mask_in_length
 
 
 def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,