Commit 6af9aba
ShashankMosaicML committed Nov 18, 2023
1 parent 511a405 commit 6af9aba
Showing 3 changed files with 12 additions and 13 deletions.
tests/test_flash_attn.py (3 additions, 0 deletions)

@@ -13,6 +13,8 @@
 @pytest.mark.gpu
 @pytest.mark.parametrize('kv_n_heads', [1, 2, 4, 8])
 def test_gqa_kv_repetition(kv_n_heads: int):
+    # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same
+    # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
     if not is_flash_v2_installed():
         pytest.skip('GQA natively only supported by Flash Attention after v2.')
     d = 128

@@ -82,6 +84,7 @@ def test_gqa_kv_repetition(kv_n_heads: int):

 @pytest.mark.gpu
 def test_seq_id_masking_FA_v2():
+    # Test that flash attention v2 with sequence id masking works correctly.
     if not is_flash_v2_installed(v2_version='v2.1.2'):
         pytest.skip(
             'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
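For context, the property the new comment describes can be illustrated outside the flash-attn kernel with plain PyTorch. The sketch below is only an illustration of grouped-query attention, not the repo's test code or the flash attention v2 API; the helper names and shapes are hypothetical.

# Minimal plain-PyTorch sketch: with GQA, explicitly repeating the kv heads up to
# n_heads is equivalent to letting each group of query heads share one kv head.
import torch

def naive_attn_repeat_kv(q, k, v, kv_n_heads):
    # q: (b, n_heads, s, d); k, v: (b, kv_n_heads, s, d)
    n_heads, d = q.shape[1], q.shape[-1]
    k = k.repeat_interleave(n_heads // kv_n_heads, dim=1)
    v = v.repeat_interleave(n_heads // kv_n_heads, dim=1)
    scores = (q @ k.transpose(-2, -1)) / d**0.5
    return scores.softmax(dim=-1) @ v

def naive_attn_grouped(q, k, v, kv_n_heads):
    # Same computation, but the query heads are grouped over the shared kv heads.
    b, n_heads, s, d = q.shape
    q = q.view(b, kv_n_heads, n_heads // kv_n_heads, s, d)
    scores = (q @ k.unsqueeze(2).transpose(-2, -1)) / d**0.5
    out = scores.softmax(dim=-1) @ v.unsqueeze(2)
    return out.reshape(b, n_heads, s, d)

torch.manual_seed(0)
q = torch.randn(2, 8, 16, 128)
k = torch.randn(2, 2, 16, 128)
v = torch.randn(2, 2, 16, 128)
assert torch.allclose(naive_attn_repeat_kv(q, k, v, 2),
                      naive_attn_grouped(q, k, v, 2),
                      atol=1e-5)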
tests/test_flash_triton_torch.py (5 additions, 2 deletions)

@@ -106,7 +106,7 @@ def test_attn_impl(attn_impl_0: str,
     assert s >= 8
     sequence_id = torch.Tensor([[0] * 4 + [1] * (s - 4),
                                 [0] * 8 + [1] * (s - 8)
-                                ]).to(device=device, dtype=torch.long)
+                                ]).to(device=device).long()

     cfg.attn_impl = attn_impl_0
     attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)

@@ -140,7 +140,10 @@ def gen_bias(attn_impl: str):
         )
         if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None:
             assert isinstance(attn_bias, torch.Tensor)  # pyright
-            attn_bias = apply_sequence_id(attn_bias, sequence_id, s)  # type: ignore
+            attn_bias = apply_sequence_id(
+                attn_bias,
+                sequence_id,  # type: ignore
+                s)

         return attn_bias
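The apply_sequence_id call reformatted above adds a sequence-id mask to the attention bias so that tokens from different packed sequences cannot attend to each other. A rough sketch of that idea in plain PyTorch (a hypothetical helper, not the actual llm-foundry implementation) looks like this:

import torch

def sequence_id_bias(sequence_id: torch.Tensor, seq_len: int) -> torch.Tensor:
    # sequence_id: (batch, seq_len) long tensor of packed-sequence ids.
    # Positions with different ids get a large negative bias so softmax zeroes them out.
    same_seq = sequence_id[:, :, None] == sequence_id[:, None, :]  # (b, s, s)
    bias = torch.zeros(sequence_id.shape[0], 1, seq_len, seq_len)  # (b, 1, s, s), broadcast over heads
    return bias.masked_fill(~same_seq.unsqueeze(1), torch.finfo(torch.float32).min)

# Two packed sequences of length 4 each: the resulting bias is block-diagonal.
sequence_id = torch.tensor([[0] * 4 + [1] * 4])
print(sequence_id_bias(sequence_id, 8)[0, 0])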
tests/test_model.py (4 additions, 11 deletions)

@@ -515,7 +515,6 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     assert block.resid_ffn_dropout.p == 0.2


-
 @pytest.mark.parametrize('attention_impl', [
     'torch',
     pytest.param('flash', marks=pytest.mark.gpu),

@@ -546,22 +545,16 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     },
 }])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
-def test_sequence_id_based_masking(attention_impl: str,
-                                   pos_emb_config: dict,
+def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict,
                                    tie_word_embeddings: bool):
     # Testing the output of concatenated sequence with sequence id masking vs individual sequences.
-    device = 'gpu' if torch.cuda.is_available() else 'cpu'
-    if not torch.cuda.is_available() and device == 'gpu':
-        pytest.skip(
-            f'This test requires CUDA to be available in order to run with {attention_impl} attention.'
-        )
     alibi = pos_emb_config['alibi']
     if alibi and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')

     rope = pos_emb_config['rope']
-    if rope and pos_emb_config['rope_impl'] == 'dail' and (
-            device != 'gpu' or not is_flash_v2_installed()):
+    if rope and pos_emb_config[
+            'rope_impl'] == 'dail' and not is_flash_v2_installed():
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')

@@ -571,7 +564,7 @@ def test_sequence_id_based_masking(attention_impl: str,
             'Using sequence id with flash attention requires flash attention v2.1.2 or higher.'
         )

-    composer_device = get_device(device)
+    composer_device = get_device(None)

     hf_config = MPTConfig(
         init_device='cpu',
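test_sequence_id_based_masking checks that a packed (concatenated) batch with sequence-id masking produces the same outputs as running the sequences individually. The same property can be demonstrated with a tiny plain-attention sketch; everything below (causal_attn, the shapes, the bias construction) is a hypothetical illustration, not the MPT model used by the test.

import torch

def causal_attn(q, k, v, bias=None):
    # Plain single-head causal attention; bias is an additive mask (0 or -inf).
    s, d = q.shape[-2], q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / d**0.5
    scores = scores + torch.full((s, s), float('-inf')).triu(1)  # causal mask
    if bias is not None:
        scores = scores + bias
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
s1, s2, d = 4, 6, 16
q, k, v = (torch.randn(1, 1, s1 + s2, d) for _ in range(3))

# Block-diagonal sequence-id mask: the two packed sequences cannot attend to each other.
seq_id = torch.tensor([0] * s1 + [1] * s2)
bias = torch.zeros(s1 + s2, s1 + s2).masked_fill(seq_id[:, None] != seq_id[None, :],
                                                 float('-inf'))

packed = causal_attn(q, k, v, bias=bias)
separate = torch.cat([
    causal_attn(q[..., :s1, :], k[..., :s1, :], v[..., :s1, :]),
    causal_attn(q[..., s1:, :], k[..., s1:, :], v[..., s1:, :]),
], dim=-2)
assert torch.allclose(packed, separate, atol=1e-6)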
