Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored and joecummings committed Jan 11, 2024
1 parent ebe42a5 commit 3019d1f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def flashattn_forward_with_s2attn(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
Expand Down Expand Up @@ -270,6 +270,7 @@ def flashattn_forward_with_s2attn(
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
# pylint: disable=duplicate-code

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
Expand Down Expand Up @@ -318,7 +319,9 @@ def flashattn_forward_with_s2attn(
.permute(0, 3, 1, 2, 4, 5)
.reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
)
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x = rearrange( # pylint: disable=invalid-name
qkv, "b s three h d -> b s (three h d)"
)
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
cu_q_len_tmp = torch.arange(
0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
Expand Down

0 comments on commit 3019d1f

Please sign in to comment.