From 5f880939efda1dc6435a5cea24978fc4b0d3844e Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Tue, 3 Dec 2024 22:28:06 -0800
Subject: [PATCH] ..

---
 llmfoundry/models/mpt/modeling_mpt.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 969a8b56a2..0c1ab4305a 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -237,7 +237,6 @@ def gen_sequence_id_info(
         ```.
         (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .)
     """
-    sequence_id_info = None
     if (sequence_id is not None) and attn_uses_sequence_id and (
         attn_impl == 'flash' or attn_impl == 'flex'
     ):
@@ -271,9 +270,9 @@ def gen_sequence_id_info(
             mode='constant',
             value=0,
         )
-        sequence_id_info = attention_mask_in_length
+        return attention_mask_in_length
 
-    return sequence_id_info
+    return None
 
 
 def gen_flash_attn_padding_info(
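
Note (not part of the patch): the change replaces the `sequence_id_info = None` sentinel with early returns, so the flash/flex branch returns `attention_mask_in_length` directly and the fall-through path returns `None`. Below is a minimal, hedged sketch of that control flow. The function name `gen_sequence_id_info_sketch`, the reduced parameter list, and the one-hot mask construction are stand-ins chosen to keep the sketch self-contained; only the `torch.nn.functional.pad(..., mode='constant', value=0)` call and the early-return shape are taken from the diff itself, and the real `gen_sequence_id_info` in `modeling_mpt.py` does additional validation and attention-mask handling.

```python
from typing import Optional

import torch


def gen_sequence_id_info_sketch(
    sequence_id: Optional[torch.Tensor],
    S: int,
    attn_uses_sequence_id: bool,
    attn_impl: str,
) -> Optional[torch.Tensor]:
    """Sketch of the early-return control flow introduced by the patch."""
    if (sequence_id is not None) and attn_uses_sequence_id and (
        attn_impl == 'flash' or attn_impl == 'flex'
    ):
        # Stand-in mask construction (assumption, simplified): count tokens per
        # packed sub-sequence by one-hot encoding the sequence ids and summing
        # over the time dimension, then pad out to length S as in the diff.
        attention_mask_in_length = torch.nn.functional.one_hot(
            sequence_id,
        ).sum(dim=1)
        attention_mask_in_length = torch.nn.functional.pad(
            attention_mask_in_length,
            (0, S - attention_mask_in_length.shape[-1]),
            mode='constant',
            value=0,
        )
        # Patch change: return from the branch directly instead of stashing the
        # result in a `sequence_id_info` variable that is returned at the end.
        return attention_mask_in_length

    return None


if __name__ == '__main__':
    # Batch of 2, seqlen 6; the first row packs two sequences of length 3.
    sequence_id = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0]])
    print(
        gen_sequence_id_info_sketch(
            sequence_id,
            S=6,
            attn_uses_sequence_id=True,
            attn_impl='flash',
        ),
    )
    # tensor([[3, 3, 0, 0, 0, 0],
    #         [6, 0, 0, 0, 0, 0]])
```

The early-return form avoids carrying a mutable sentinel through the function, which makes the two possible outcomes (a per-sequence length mask, or `None` when sequence-id masking is not in use) explicit at each exit point.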