Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Nov 30, 2023
1 parent b855100 commit 371e3a2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/test_flash_triton_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ def gen_bias(attn_impl: str):
x1.requires_grad = True

with torch.autocast(x0.device.type):
attn_bias = gen_bias(attn_impl_0)

attn_bias_0 = gen_bias(attn_impl_0)
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
Expand Down Expand Up @@ -206,15 +205,15 @@ def gen_bias(attn_impl: str):

y0, _, _ = attn0(x0,
past_key_value=None,
attn_bias=attn_bias,
attn_bias=attn_bias_0,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_0)
attn_bias = gen_bias(attn_impl_1)
attn_bias_1 = gen_bias(attn_impl_1)
y1, _, _ = attn1(x1,
past_key_value=None,
attn_bias=attn_bias,
attn_bias=attn_bias_1,
attention_mask=attention_mask,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
Expand Down

0 comments on commit 371e3a2

Please sign in to comment.