From 371e3a21eee31b740e9a062e1f19d472a40494a1 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Thu, 30 Nov 2023 23:25:03 +0000
Subject: [PATCH] ..

---
 tests/test_flash_triton_torch.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index f042dfd19c..454fda311d 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -176,8 +176,7 @@ def gen_bias(attn_impl: str):
     x1.requires_grad = True
 
     with torch.autocast(x0.device.type):
-        attn_bias = gen_bias(attn_impl_0)
-
+        attn_bias_0 = gen_bias(attn_impl_0)
         rotary_emb_w_meta_info = None
         if rope:
             rotary_embedding = gen_rotary_embedding(
@@ -206,15 +205,15 @@ def gen_bias(attn_impl: str):
 
         y0, _, _ = attn0(x0,
                          past_key_value=None,
-                         attn_bias=attn_bias,
+                         attn_bias=attn_bias_0,
                          attention_mask=attention_mask,
                          rotary_emb_w_meta_info=rotary_emb_w_meta_info,
                          is_causal=True,
                          attention_mask_in_length=attention_mask_in_length_0)
-        attn_bias = gen_bias(attn_impl_1)
+        attn_bias_1 = gen_bias(attn_impl_1)
         y1, _, _ = attn1(x1,
                          past_key_value=None,
-                         attn_bias=attn_bias,
+                         attn_bias=attn_bias_1,
                          attention_mask=attention_mask,
                          rotary_emb_w_meta_info=rotary_emb_w_meta_info,
                          is_causal=True,
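
Note: the rename is variable hygiene rather than a behavior change. Previously a single
`attn_bias` variable was built for `attn_impl_0`, passed to `attn0`, then overwritten
for `attn_impl_1`; splitting it into `attn_bias_0` and `attn_bias_1` keeps each call
paired with the bias generated for its own implementation. The following self-contained
toy sketch illustrates that pattern; `gen_bias`, `toy_attn`, and all shapes and values
here are hypothetical stand-ins, not the real fixtures from
tests/test_flash_triton_torch.py.

    from typing import Optional

    import torch


    def gen_bias(attn_impl: str) -> Optional[torch.Tensor]:
        # Toy stand-in: 'flash' takes no explicit bias; 'torch' gets an
        # additive causal mask (-inf above the main diagonal).
        if attn_impl == 'flash':
            return None
        seq_len = 4
        return torch.triu(
            torch.full((seq_len, seq_len), float('-inf')), diagonal=1)


    def toy_attn(x: torch.Tensor,
                 attn_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Minimal attention: scores, optional additive bias, softmax, mix.
        scores = x @ x.transpose(-2, -1)
        if attn_bias is not None:
            scores = scores + attn_bias
        return torch.softmax(scores, dim=-1) @ x


    x = torch.randn(1, 4, 8)
    # One bias variable per implementation, as the patch enforces: neither
    # call can accidentally consume a bias built for the other.
    attn_bias_0 = gen_bias('flash')
    y0 = toy_attn(x, attn_bias_0)
    attn_bias_1 = gen_bias('torch')
    y1 = toy_attn(x, attn_bias_1)
    print(torch.allclose(y0, y1))  # False: the toy impls differ by bias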