From a70f05e4fdeb90e392b68a164a39390400427ca1 Mon Sep 17 00:00:00 2001
From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com>
Date: Fri, 17 Nov 2023 14:26:44 -0800
Subject: [PATCH] Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
---
 llmfoundry/models/mpt/modeling_mpt.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 91fb3d2fa2..364f9a22a2 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -147,8 +147,8 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     return query_attention_mask_in_length, key_attention_mask_in_length
 
 def apply_sequence_id(attn_bias: torch.Tensor,
-        sequence_id: torch.LongTensor,
-        max_seq_len: int) -> torch.Tensor:
+                      sequence_id: torch.LongTensor,
+                      max_seq_len: int) -> torch.Tensor:
     seq_len = sequence_id.shape[-1]
     if seq_len > max_seq_len:
         raise ValueError(
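
For context on the hunk above: the visible lines show only the signature and the length check of apply_sequence_id. Below is a minimal, hypothetical sketch of what a function with this signature could do, masking attention across packed-sequence boundaries based on sequence_id. The name apply_sequence_id_sketch and the masking details are assumptions for illustration; this is not the llm-foundry implementation.

# Hypothetical sketch, not the llm-foundry code: given an additive attention
# bias and a per-token sequence_id for a packed batch, disallow attention
# between tokens that belong to different packed sequences.
import torch


def apply_sequence_id_sketch(attn_bias: torch.Tensor,
                             sequence_id: torch.LongTensor,
                             max_seq_len: int) -> torch.Tensor:
    seq_len = sequence_id.shape[-1]
    if seq_len > max_seq_len:
        raise ValueError(
            f'sequence_id length {seq_len} exceeds max_seq_len {max_seq_len}')

    # True where query token i and key token j come from different packed
    # sequences; those positions should not attend to each other.
    cross_sequence = (sequence_id.unsqueeze(-1) !=
                      sequence_id.unsqueeze(-2))  # (batch, seq_len, seq_len)

    # Broadcast over the head dimension and fill masked positions with the
    # dtype's minimum so they vanish after softmax.
    min_val = torch.finfo(attn_bias.dtype).min
    return attn_bias.masked_fill(cross_sequence.unsqueeze(1), min_val)


# Usage example: one batch element containing two packed sequences of length 2.
if __name__ == '__main__':
    sequence_id = torch.tensor([[0, 0, 1, 1]])
    attn_bias = torch.zeros(1, 1, 4, 4)
    print(apply_sequence_id_sketch(attn_bias, sequence_id, max_seq_len=4))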