fixing tests
ShashankMosaicML committed Dec 4, 2024
1 parent 5f88093 commit 661f7f6
Showing 3 changed files with 32 additions and 32 deletions.
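Every hunk below makes one of two mechanical changes: the FlexAttention version gate in each test is bumped from torch 2.5.0 to 2.6.0, and the flex_attn_kwargs key 'sequence_id_transforms' is renamed to 'sequence_id_info'. A minimal sketch of the updated gate follows; the helper name is hypothetical (the tests inline this check), and it assumes the pytest, torch, and packaging.version imports already used by these tests:

    import pytest
    import torch
    from packaging import version

    def skip_if_flex_unsupported(attn_impl: str) -> None:
        # Skip FlexAttention tests on torch builds older than 2.6.0.
        # split('.dev')[0] strips the suffix of nightly versions such as '2.6.0.dev20241112'.
        if attn_impl == 'flex' and version.parse(
            torch.__version__.split('.dev')[0],
        ) < version.parse('2.6.0'):
            pytest.skip(
                f'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
            )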
6 changes: 3 additions & 3 deletions tests/models/layers/test_attention.py
@@ -177,9 +177,9 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
# Test that sliding window attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
dtype = torch.bfloat16
device = 'cuda'
@@ -218,7 +218,7 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}

output_1, _, _ = attention_implementations.get(attn_impl)(
28 changes: 14 additions & 14 deletions tests/models/layers/test_flash_attn.py
@@ -33,9 +33,9 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
# whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
d = 128
n_heads = 8
@@ -70,7 +70,7 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}

output_1, _, _ = attention_implementations.get(attn_impl)(
@@ -118,7 +118,7 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}

output_2, _, _ = attention_implementations.get(attn_impl)(
@@ -156,9 +156,9 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
# Test that flash attention v2 with sequence id masking works correctly.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
d = 128
n_heads = 4
@@ -201,7 +201,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {
+'sequence_id_info': {
'sequence_id': sequence_id,
},
}
@@ -249,7 +249,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {
+'sequence_id_info': {
'sequence_id': sequence_id,
},
}
@@ -300,9 +300,9 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
# Test that sliding window attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
dtype = torch.bfloat16
device = 'cuda'
@@ -345,7 +345,7 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}
output_1, _, _ = attention_implementations.get(attn_impl)(
query=query_1,
@@ -444,9 +444,9 @@ def test_attn_logit_softcapping(
# Test that attn_logit_softcapping in attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
if attn_impl == 'flex' and attn_logit_softcapping is not None:
if int(attn_logit_softcapping) != attn_logit_softcapping:
@@ -492,7 +492,7 @@ def test_attn_logit_softcapping(
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}
output_1, _, _ = attention_implementations.get(attn_impl)(
query=query_1,
30 changes: 15 additions & 15 deletions tests/models/layers/test_flash_torch.py
@@ -104,9 +104,9 @@ def test_attn_impl(
"""
if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
alibi = pos_emb_config['alibi']
rope = pos_emb_config['rope']
@@ -300,10 +300,10 @@ def gen_bias(attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}
if sequence_id is not None:
-extra_kwargs['flex_attn_kwargs']['sequence_id_transforms'][
+extra_kwargs['flex_attn_kwargs']['sequence_id_info'][
'sequence_id'] = sequence_id
y0, _, _ = attn0(
x0,
@@ -334,7 +334,7 @@ def gen_bias(attn_impl: str):
'compiled_create_block_mask': torch.compile(create_block_mask),
}
if sequence_id is not None:
-extra_kwargs['flex_attn_kwargs']['sequence_id_transforms'] = {
+extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = {
'sequence_id': sequence_id,
}

@@ -390,9 +390,9 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
from llmfoundry.models.layers import attention

@@ -454,7 +454,7 @@ def gen_tca_mask():
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}
y0, _, _ = mmhsa(
x0,
@@ -521,9 +521,9 @@ def test_grouped_attention_heads(
"""Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
from llmfoundry.models.layers import attention

@@ -562,7 +562,7 @@ def test_grouped_attention_heads(
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {},
+'sequence_id_info': {},
}
y0, _, _ = mmhsa(
x0,
@@ -641,9 +641,9 @@ def test_reuse_prev_layer_kv_cache(
"""Checks reusing previous layer's kv cache."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.5.0'):
+) < version.parse('2.6.0'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
)
alibi = pos_emb_config['alibi']
rope = pos_emb_config['rope']
@@ -786,7 +786,7 @@ def gen_bias(attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {
+'sequence_id_info': {
'sequence_id': sequence_id,
},
}
@@ -820,7 +820,7 @@ def gen_bias(attn_impl: str):
'compiled_flex_attention':
flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
'compiled_create_block_mask': torch.compile(create_block_mask),
-'sequence_id_transforms': {
+'sequence_id_info': {
'sequence_id': sequence_id,
},
}
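For reference, the renamed keyword as it sits at the updated call sites, shown as a minimal sketch. It assumes flex_attention and create_block_mask come from torch.nn.attention.flex_attention (available in recent torch releases); the sequence_id tensor is a made-up packed-batch example, not taken from the tests:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # Made-up packed batch: each row holds several concatenated sequences,
    # identified per token by an integer id.
    sequence_id = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 2, 2]])

    extra_kwargs = {
        'flex_attn_kwargs': {
            'compiled_flex_attention': flex_attention,
            'compiled_create_block_mask': torch.compile(create_block_mask),
            'sequence_id_info': {  # formerly 'sequence_id_transforms'
                'sequence_id': sequence_id,
            },
        },
    }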
