diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index e31864b838..8e43da1110 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
     sliding_window,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
-    if attention_mask is None:
+    if attention_mask is None or sliding_window is None:
         return attention_mask
 
     # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
@@ -151,7 +151,8 @@ def flashattn_forward(
     )
 
     use_sliding_windows = (
-        hasattr(self.config, "sliding_window") is not None
+        # Default to None so a config without the attribute skips sliding windows.
+        getattr(self.config, "sliding_window", None) is not None
         and kv_seq_len > self.config.sliding_window
     )
 