diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index f6ca4310b8..3ed824738f 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -148,6 +148,7 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking. attn_impl (str): Attention implementation. This function only creates attention_mask_in_length for flash attention. attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len) + Returns: attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: ```