Commit b4a7752

..

Shashank Rajput committed Jan 2, 2024
1 parent 856ae68 commit b4a7752
Showing 2 changed files with 12 additions and 14 deletions.
10 changes: 4 additions & 6 deletions tests/models/layers/test_flash_triton_torch.py
@@ -1,7 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

-from typing import Union

import pytest
import torch
@@ -78,7 +77,7 @@ def test_attn_impl(attn_impl_0: str,
rope = pos_emb_config['rope']
if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'
) and not is_flash_v2_installed(v2_version='v2.4.2'):
-pytest.skip('flash attention below v2.4.2 do not support alibi.')
+pytest.skip('flash attention below v2.4.2 does not support alibi.')
if rope and (pos_emb_config['rope_impl']
== 'dail') and (not is_flash_v2_installed()):
pytest.skip('dail implementation of rope requires flash attention 2.')
@@ -129,8 +128,7 @@ def test_attn_impl(attn_impl_0: str,
# to simulate padding
attention_mask[:, -s // 3:] = 0

-def gen_bias(attn_impl: str,
-             attention_mask_in_length: Union[torch.Tensor, None] = None):
+def gen_bias(attn_impl: str):
causal = True
attn_bias = None
bs = attention.attn_bias_shape(attn_impl,
@@ -184,7 +182,7 @@ def gen_bias(attn_impl: str,
x1.requires_grad = True

with torch.autocast(x0.device.type):
-attn_bias_0 = gen_bias(attn_impl_0, attention_mask_in_length_0)
+attn_bias_0 = gen_bias(attn_impl_0)
rotary_emb_w_meta_info = None
if rope:
rotary_embedding = gen_rotary_embedding(
@@ -218,7 +216,7 @@ def gen_bias(attn_impl: str,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
is_causal=True,
attention_mask_in_length=attention_mask_in_length_0)
-attn_bias_1 = gen_bias(attn_impl_1, attention_mask_in_length_1)
+attn_bias_1 = gen_bias(attn_impl_1)
y1, _, _ = attn1(x1,
past_key_value=None,
attn_bias=attn_bias_1,
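For context, the change in this file keys the test's bias helper only on the attention implementation and drops the unused `attention_mask_in_length` argument. Below is a minimal, self-contained sketch of the resulting call pattern; the branch logic, sequence length `s`, and bias values are illustrative assumptions, not the repository's `attn_bias_shape`-based code shown in the diff.

```python
from typing import Optional

import torch


def gen_bias(attn_impl: str, s: int = 4) -> Optional[torch.Tensor]:
    # Illustrative stand-in for the test helper: return no materialized bias
    # for 'flash' and an additive causal mask otherwise. The real test derives
    # the bias shape via attention.attn_bias_shape(...) instead.
    if attn_impl == 'flash':
        return None
    causal_mask = torch.triu(torch.full((s, s), float('-inf')), diagonal=1)
    return causal_mask.view(1, 1, s, s)


# Mirrors the updated call sites gen_bias(attn_impl_0) / gen_bias(attn_impl_1).
attn_bias_0 = gen_bias('torch')   # tensor of shape (1, 1, 4, 4)
attn_bias_1 = gen_bias('flash')   # None
```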
16 changes: 8 additions & 8 deletions tests/models/test_model.py
@@ -649,7 +649,7 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
alibi = pos_emb_config['alibi']
if alibi and attention_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

rope = pos_emb_config['rope']
if rope and pos_emb_config[
@@ -769,7 +769,7 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
alibi = pos_emb_config['alibi']
if alibi and attention_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

rope = pos_emb_config['rope']
if rope and pos_emb_config[
@@ -1033,7 +1033,7 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
if pos_emb_config[
'alibi'] and attention_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -1284,7 +1284,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, pos_emb_config: dict):
if pos_emb_config[
'alibi'] and attn_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
pytest.skip(
@@ -1423,7 +1423,7 @@ def test_forward_with_cache(attn_impl: str, pos_emb_config: dict,
if pos_emb_config[
'alibi'] and attn_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -1562,7 +1562,7 @@ def test_generate_with_past_kv(attn_impl: str, pos_emb_config: dict,
if pos_emb_config[
'alibi'] and attn_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
pytest.skip(
@@ -1671,7 +1671,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str,
if pos_emb_config[
'alibi'] and attn_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')

if pos_emb_config['rope'] and pos_emb_config[
'rope_impl'] == 'dail' and not is_flash_v2_installed():
@@ -1862,7 +1862,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
if pos_emb_config[
'alibi'] and attn_impl == 'flash' and not is_flash_v2_installed(
v2_version='v2.4.2'):
-pytest.skip(f'flash attention below v2.4.2 do not support alibi.')
+pytest.skip(f'flash attention below v2.4.2 does not support alibi.')
if attn_impl in ['flash', 'triton']:
pytest.skip(f'output_attentions only implemented with torch attention.')
if pos_emb_config['rope'] and pos_emb_config[
