diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index f727c74b82..dda5da2b7a 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn( # [bsz, nh, q_len, hd] # pylint: disable=duplicate-code - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) @@ -435,13 +430,7 @@ def flashattn_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 8143750f00..0c1a4e8224 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -80,11 +80,7 @@ def xformers_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids )