From 0a35298c1c8f397ccb7104f412377ea7bb8d09b1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Mar 2024 13:43:27 -0400 Subject: [PATCH] chore: lint --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 6c4df60597..dda5da2b7a 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -284,9 +284,7 @@ def flashattn_forward_with_s2attn( # [bsz, nh, q_len, hd] # pylint: disable=duplicate-code - cos, sin = self.rotary_emb( - value_states, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) @@ -432,9 +430,7 @@ def flashattn_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - cos, sin = self.rotary_emb( - value_states, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids )