From 9923f337e31ab54c10fd87164e569567f6086afc Mon Sep 17 00:00:00 2001 From: BMPixel Date: Tue, 26 Mar 2024 19:59:21 +0800 Subject: [PATCH 1/2] remove seq_len in llama rotary_emb --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 11 ++--------- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py | 6 +----- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index f727c74b82..6c4df60597 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -284,11 +284,8 @@ def flashattn_forward_with_s2attn( # [bsz, nh, q_len, hd] # pylint: disable=duplicate-code - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids + value_states, position_ids=position_ids ) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids @@ -435,12 +432,8 @@ def flashattn_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids + value_states, position_ids=position_ids ) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 8143750f00..0c1a4e8224 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -80,11 +80,7 @@ def xformers_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) From 0a35298c1c8f397ccb7104f412377ea7bb8d09b1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Mar 2024 13:43:27 -0400 Subject: [PATCH 2/2] chore: lint --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6c4df60597..dda5da2b7a 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -284,9 +284,7 @@ def flashattn_forward_with_s2attn( # [bsz, nh, q_len, hd] # pylint: disable=duplicate-code - cos, sin = self.rotary_emb( - value_states, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) @@ -432,9 +430,7 @@ def flashattn_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - cos, sin = self.rotary_emb( - value_states, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids )