Skip to content

Commit

Permalink
Remove seq_len arg in rotary_emb (#1443)
Browse files Browse the repository at this point in the history
* remove seq_len in llama rotary_emb

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
BMPixel and winglian authored Mar 26, 2024
1 parent 9ae17ae commit aee48ed
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 18 deletions.
15 changes: 2 additions & 13 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn(
# [bsz, nh, q_len, hd]
# pylint: disable=duplicate-code

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
Expand Down Expand Up @@ -435,13 +430,7 @@ def flashattn_forward(
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
Expand Down
6 changes: 1 addition & 5 deletions src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ def xformers_forward(
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
Expand Down

0 comments on commit aee48ed

Please sign in to comment.