diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 06f74ba04d..97f0477649 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -520,9 +520,6 @@ def custom_forward(*inputs): # None for past_key_value return module( *inputs, - past_key_value, # pylint: disable=(cell-var-from-loop) - output_attentions, - padding_mask=padding_mask, ) return custom_forward @@ -532,7 +529,10 @@ def custom_forward(*inputs): hidden_states, attention_mask, position_ids, + past_key_value, + output_attentions, None, + padding_mask, cu_seqlens, max_seqlen, )