diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index f0fa807fa6..f380c3f2ae 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -321,6 +321,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -366,6 +373,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
             )
         else:
@@ -398,7 +406,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
             )
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index 26b511d06e..e31864b838 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -201,6 +201,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -213,7 +215,7 @@ def flashattn_forward(
             qkv,
             cu_seqlens,
             max_seqlen,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=True,
             window_size=window_size,
@@ -239,7 +241,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -253,6 +255,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
                 window_size=window_size,
             )
@@ -286,7 +289,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
                 window_size=window_size,
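
Every hunk above threads the same `dropout_rate` value into flash-attn's `dropout_p` argument. A minimal sketch of that gating pattern follows; the `DummyAttention` class and its `dropout_rate` helper are hypothetical stand-ins for illustration, not part of the patch or of axolotl.

```python
import torch.nn as nn


class DummyAttention(nn.Module):
    """Hypothetical stand-in showing how the patch derives dropout_p."""

    def __init__(self, attention_dropout: float = 0.1):
        super().__init__()
        self.attention_dropout = attention_dropout

    def dropout_rate(self) -> float:
        # Same expression the patch adds: dropout is applied only while the
        # module is in training mode, and getattr() falls back to 0.0 when
        # the config never defined attention_dropout.
        return 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)


attn = DummyAttention(attention_dropout=0.1)
attn.train()
assert attn.dropout_rate() == 0.1  # passed to flash-attn as dropout_p during training
attn.eval()
assert attn.dropout_rate() == 0.0  # dropout disabled at inference/eval time
```

Gating on `self.training` keeps evaluation and generation deterministic, while the `getattr` fallback leaves behavior unchanged (a 0.0 rate, as before the patch) for models whose attention modules carry no dropout attribute.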