diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index f5c38617fb..ed58f23e2a 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -20,6 +20,12 @@ ) from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'): + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + @pytest.mark.gpu @pytest.mark.skipif( @@ -33,9 +39,9 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int): # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) d = 128 n_heads = 8 @@ -67,9 +73,8 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int): } elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } @@ -115,9 +120,8 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int): } elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } @@ -156,9 +160,9 @@ def test_seq_id_masking_FA_v2(attn_impl: str): # Test that flash attention v2 with sequence id masking works correctly. if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) d = 128 n_heads = 4 @@ -198,9 +202,8 @@ def test_seq_id_masking_FA_v2(attn_impl: str): extra_attn_kwargs['flash_attn_padding_info'] = flash_attn_padding_info_1 elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': { 'sequence_id': sequence_id, }, @@ -246,9 +249,8 @@ def test_seq_id_masking_FA_v2(attn_impl: str): ] = flash_attn_padding_info_2 elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': { 'sequence_id': sequence_id, }, @@ -300,9 +302,9 @@ def test_alibi_bias(attn_impl: str, n_heads: int): # Test that sliding window attention works as expected. if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) dtype = torch.bfloat16 device = 'cuda' @@ -342,9 +344,8 @@ def test_alibi_bias(attn_impl: str, n_heads: int): } elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } output_1, _, _ = attention_implementations.get(attn_impl)( @@ -444,9 +445,9 @@ def test_attn_logit_softcapping( # Test that attn_logit_softcapping in attention works as expected. if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) if attn_impl == 'flex' and attn_logit_softcapping is not None: if int(attn_logit_softcapping) != attn_logit_softcapping: @@ -489,9 +490,8 @@ def test_attn_logit_softcapping( } elif attn_impl == 'flex': extra_attn_kwargs = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } output_1, _, _ = attention_implementations.get(attn_impl)( diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index fbb5989051..39331c1918 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -21,6 +21,12 @@ gen_sequence_id_info, ) +compiled_flex_attention = flex_attention +compiled_create_block_mask = create_block_mask +if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'): + compiled_flex_attention = torch.compile(flex_attention) + compiled_create_block_mask = torch.compile(create_block_mask) + def allclose_helper( t0: torch.Tensor, @@ -104,9 +110,9 @@ def test_attn_impl( """ if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] @@ -293,9 +299,8 @@ def gen_bias(attn_impl: str): extra_kwargs = {} if attn_impl_0 == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } if sequence_id is not None: @@ -325,9 +330,8 @@ def gen_bias(attn_impl: str): extra_kwargs = {} if attn_impl_1 == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, } if sequence_id is not None: extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = { @@ -386,9 +390,9 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'): """Compare diff attn_impl to torch.nn.MultiheadAttention.""" if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) from llmfoundry.models.layers import attention @@ -447,9 +451,8 @@ def gen_tca_mask(): extra_kwargs = {} if attn_impl == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } y0, _, _ = mmhsa( @@ -517,9 +520,9 @@ def test_grouped_attention_heads( """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads.""" if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) from llmfoundry.models.layers import attention @@ -555,9 +558,8 @@ def test_grouped_attention_heads( extra_kwargs = {} if attn_impl == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': {}, } y0, _, _ = mmhsa( @@ -637,9 +639,9 @@ def test_reuse_prev_layer_kv_cache( """Checks reusing previous layer's kv cache.""" if attn_impl == 'flex' and version.parse( torch.__version__.split('.dev')[0], - ) < version.parse('2.6.0'): + ) < version.parse('2.5.1'): pytest.skip( - 'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.', + 'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.', ) alibi = pos_emb_config['alibi'] rope = pos_emb_config['rope'] @@ -777,9 +779,8 @@ def gen_bias(attn_impl: str): extra_kwargs = {} if attn_impl == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': { 'sequence_id': sequence_id, }, @@ -811,9 +812,8 @@ def gen_bias(attn_impl: str): extra_kwargs = {} if attn_impl == 'flex': extra_kwargs['flex_attn_kwargs'] = { - 'compiled_flex_attention': - flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis. - 'compiled_create_block_mask': torch.compile(create_block_mask), + 'compiled_flex_attention': compiled_flex_attention, + 'compiled_create_block_mask': compiled_create_block_mask, 'sequence_id_info': { 'sequence_id': sequence_id, },