diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index 6a0bcfee18..9533fd5db1 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -177,9 +177,9 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
     # Test that sliding window attention works as expected.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     dtype = torch.bfloat16
     device = 'cuda'
@@ -218,7 +218,7 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
 
     output_1, _, _ = attention_implementations.get(attn_impl)(
diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py
index c1315b9f5e..f5c38617fb 100644
--- a/tests/models/layers/test_flash_attn.py
+++ b/tests/models/layers/test_flash_attn.py
@@ -33,9 +33,9 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
     # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     d = 128
     n_heads = 8
@@ -70,7 +70,7 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
 
     output_1, _, _ = attention_implementations.get(attn_impl)(
@@ -118,7 +118,7 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
 
     output_2, _, _ = attention_implementations.get(attn_impl)(
@@ -156,9 +156,9 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
     # Test that flash attention v2 with sequence id masking works correctly.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     d = 128
     n_heads = 4
@@ -201,7 +201,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {
+            'sequence_id_info': {
                 'sequence_id': sequence_id,
             },
         }
@@ -249,7 +249,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {
+            'sequence_id_info': {
                 'sequence_id': sequence_id,
             },
         }
@@ -300,9 +300,9 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
     # Test that sliding window attention works as expected.
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     dtype = torch.bfloat16
     device = 'cuda'
@@ -345,7 +345,7 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
     output_1, _, _ = attention_implementations.get(attn_impl)(
         query=query_1,
@@ -444,9 +444,9 @@ def test_attn_logit_softcapping(
     # Test that attn_logit_softcapping in attention works as expected.
    if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     if attn_impl == 'flex' and attn_logit_softcapping is not None:
         if int(attn_logit_softcapping) != attn_logit_softcapping:
@@ -492,7 +492,7 @@ def test_attn_logit_softcapping(
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
     output_1, _, _ = attention_implementations.get(attn_impl)(
         query=query_1,
diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py
index 8dfeab193c..b7e99d1178 100644
--- a/tests/models/layers/test_flash_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -104,9 +104,9 @@ def test_attn_impl(
     """
     if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     alibi = pos_emb_config['alibi']
     rope = pos_emb_config['rope']
@@ -300,10 +300,10 @@ def gen_bias(attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
     if sequence_id is not None:
-        extra_kwargs['flex_attn_kwargs']['sequence_id_transforms'][
+        extra_kwargs['flex_attn_kwargs']['sequence_id_info'][
             'sequence_id'] = sequence_id
     y0, _, _ = attn0(
         x0,
@@ -334,7 +334,7 @@ def gen_bias(attn_impl: str):
             'compiled_create_block_mask': torch.compile(create_block_mask),
         }
     if sequence_id is not None:
-        extra_kwargs['flex_attn_kwargs']['sequence_id_transforms'] = {
+        extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = {
             'sequence_id': sequence_id,
         }
 
@@ -390,9 +390,9 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
     """Compare diff attn_impl to torch.nn.MultiheadAttention."""
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
 
     from llmfoundry.models.layers import attention
@@ -454,7 +454,7 @@ def gen_tca_mask():
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
     y0, _, _ = mmhsa(
         x0,
@@ -521,9 +521,9 @@ def test_grouped_attention_heads(
     """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
 
     from llmfoundry.models.layers import attention
@@ -562,7 +562,7 @@ def test_grouped_attention_heads(
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {},
+            'sequence_id_info': {},
         }
     y0, _, _ = mmhsa(
         x0,
@@ -641,9 +641,9 @@ def test_reuse_prev_layer_kv_cache(
     """Checks reusing previous layer's kv cache."""
     if attn_impl == 'flex' and version.parse(
         torch.__version__.split('.dev')[0],
-    ) < version.parse('2.5.0'):
+    ) < version.parse('2.6.0'):
         pytest.skip(
-            'FlexAttention is not supported in torch version {torch.__version__}<2.5.0.',
+            'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
         )
     alibi = pos_emb_config['alibi']
     rope = pos_emb_config['rope']
@@ -786,7 +786,7 @@ def gen_bias(attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {
+            'sequence_id_info': {
                 'sequence_id': sequence_id,
             },
         }
@@ -820,7 +820,7 @@ def gen_bias(attn_impl: str):
             'compiled_flex_attention':
                 flex_attention,  # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
             'compiled_create_block_mask': torch.compile(create_block_mask),
-            'sequence_id_transforms': {
+            'sequence_id_info': {
                 'sequence_id': sequence_id,
             },
         }