From 62ca4a2b71343f9857a9fbbc1162bb68cb0cacfa Mon Sep 17 00:00:00 2001 From: DreamGenX <157678800+DreamGenX@users.noreply.github.com> Date: Fri, 26 Jan 2024 13:43:37 +0100 Subject: [PATCH] Respect sliding_window=None (#1214) --- src/axolotl/monkeypatch/mistral_attn_hijack_flash.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index e31864b838..8e43da1110 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask( sliding_window, ): # pylint: disable=unused-argument # [bsz, seq_len] - if attention_mask is None: + if attention_mask is None or sliding_window is None: return attention_mask # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios. @@ -151,7 +151,7 @@ def flashattn_forward( ) use_sliding_windows = ( - hasattr(self.config, "sliding_window") is not None + getattr(self.config, "sliding_window") is not None and kv_seq_len > self.config.sliding_window )