Commit

ShashankMosaicML committed Jan 17, 2024
1 parent 03113a9 commit 3351d23
Showing 1 changed file with 12 additions and 2 deletions.
tests/models/layers/test_flash_triton_torch.py (14 changes: 12 additions & 2 deletions)
@@ -327,11 +327,16 @@ def gen_tca_mask():
     x1.requires_grad = True

     with torch.autocast(x0.device.type):
+        flash_attn_padding_info = None
+        if attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                n, s, 0, torch.device(device), None, attention_mask)
         y0, _, _ = mmhsa(x0,
                          past_key_value=None,
                          attn_bias=None,
                          attention_mask=attention_mask,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=flash_attn_padding_info)
         y1, _ = tmhsa(x1,
                       x1,
                       x1,
@@ -401,11 +406,16 @@ def test_grouped_attention_heads(attn_impl: str,
     x0.requires_grad = True

     with torch.autocast(x0.device.type):
+        flash_attn_padding_info = None
+        if attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                n, s, 0, torch.device(device), None, attention_mask)
         y0, _, _ = mmhsa(x0,
                          past_key_value=None,
                          attn_bias=None,
                          attention_mask=attention_mask,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=flash_attn_padding_info)
         y0 *= attention_mask.unsqueeze(-1)

         loss0 = y0.sum()
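
For context, the pattern this commit introduces is sketched below: the tests build flash_attn_padding_info only when the flash implementation is under test, and pass it through to the attention module's forward call, while the other implementations keep receiving None. This is an illustrative sketch, not part of the commit; the argument comments are inferred from the variable names in the diff (n and s appear to be the batch size and sequence length used elsewhere in these tests), and the import path for gen_flash_attn_padding_info is an assumption to be checked against the repository.

# Illustrative sketch only, not part of the commit. The import path and the
# meaning of the positional arguments are assumptions inferred from the diff.
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info  # assumed path


def forward_with_padding_info(mmhsa, x0, attention_mask, attn_impl, n, s, device):
    # Only the 'flash' implementation needs the unpadding metadata; the
    # torch/triton paths receive None, matching their previous behavior.
    flash_attn_padding_info = None
    if attn_impl == 'flash':
        flash_attn_padding_info = gen_flash_attn_padding_info(
            n,                     # batch size (assumed from test naming)
            s,                     # sequence length (assumed from test naming)
            0,                     # past key/value length: no KV cache in these tests
            torch.device(device),  # device the padding metadata is created on
            None,                  # attention_mask_in_length: not used here
            attention_mask)        # padding mask for the batch
    # mmhsa is the multihead attention module under test; in these tests its
    # forward returns (output, attn_weights, past_key_value).
    return mmhsa(x0,
                 past_key_value=None,
                 attn_bias=None,
                 attention_mask=attention_mask,
                 is_causal=True,
                 flash_attn_padding_info=flash_attn_padding_info)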
