diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index db2af54631..06f74ba04d 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -522,7 +522,7 @@ def custom_forward(*inputs): *inputs, past_key_value, # pylint: disable=(cell-var-from-loop) output_attentions, - attention_mask=attention_mask, + padding_mask=padding_mask, ) return custom_forward