diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index b2bae0fbab7b..c044a3433990 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -244,7 +244,7 @@ def forward_decode( q_heads_1B4D, keys_1BPD, values_1BPD, - start_pos_ids, + cur_pos=start_pos_ids, scale=self.scale, program_config=self.model_config["SDPA_DECODE_PROGCFG"], compute_kernel_config=self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"], diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index 5f251da552b1..947e166d3e28 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -361,7 +361,7 @@ def attn_mqa( query_layer, keys, values, - [start_pos for _ in range(self.max_batch_size)], + cur_pos=[start_pos for _ in range(self.max_batch_size)], scale=self.scale, program_config=program_config, compute_kernel_config=self.attention_config["COMPUTE_KERNEL_SDPA"], diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py index 170753f68ea7..0783e8c464ea 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention_decode.py @@ -48,18 +48,23 @@ def num_to_corerange(x): ) -def get_chunk_size(s): - if s <= 32: - return 32 - if s <= 64: - return 32 - if s <= 128: - return 32 - if s <= 256: - return 256 - if s <= 2048: - return 512 - return 512 +def get_chunk_size(max_start_pos, s): + if max_start_pos <= 32: + chunk_size = 32 + elif max_start_pos <= 64: + chunk_size = 32 + elif max_start_pos <= 128: + chunk_size = 32 + elif max_start_pos <= 1024: + chunk_size = 128 + else: + chunk_size = 512 + # find maximum power of 2 divisor of s + for i in range(1, s): + if s % (2 ** (i + 1)) != 0: + break + chunk_size = min(chunk_size, 2**i) + return chunk_size def fa_rand(*shape): @@ -217,7 +222,7 @@ def run_test_sdpa_decode_multi_pos( scale = d**-0.5 start_indices = np.linspace(0, max_start_idx, b, dtype=np.int32).tolist() if b > 1 else [max_start_idx] - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -265,7 +270,7 @@ def run_test_sdpa_decode_multi_pos( tt_Q, tt_K, tt_V, - start_indices, + cur_pos=start_indices, scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, @@ -290,7 +295,7 @@ def run_test_sdpa_decode_multi_pos( expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc) @@ -315,6 +320,7 @@ def run_test_sdpa_decode_single_iter( sharded_in=False, sharded_out=False, start_indices=None, + causal=True, ): compute_grid_size = device.compute_with_storage_grid_size() if grid_size[0] > compute_grid_size.x or grid_size[1] > compute_grid_size.y: @@ -355,7 +361,7 @@ def run_test_sdpa_decode_single_iter( 
max_start_idx = max(start_indices) scale = d**-0.5 - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, q_chunk_size=padded_num_heads, @@ -363,20 +369,29 @@ def run_test_sdpa_decode_single_iter( exp_approx_mode=False, ) - padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) + padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) if causal else s # Test various sequence lengths - logger.debug(f"Testing with sequence length: {max_start_idx}") + logger.debug(f"Testing with sequence length: {max_start_idx if causal else s}") logger.debug(f"Using chunk size: {k_chunk_size}") logger.debug(f"Using padded layer length: {padded_layer_len}") logger.debug(f"Using padded num heads: {padded_num_heads}") - attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) - for i in range(b): - start_idx = start_indices[i] - attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + if causal: + attn_mask = torch.zeros((b, nh, 1, padded_layer_len)) + for i in range(b): + start_idx = start_indices[i] + attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + else: + attn_mask = torch.bernoulli( + torch.full( + (b, nh, 1, padded_layer_len), + 0.25, + ) + ) + attn_mask = attn_mask * torch.finfo(torch.float32).min - Q = fa_rand(1, b, padded_num_heads, d) + Q = fa_rand(1, b, nh, d) tt_Q = ttnn.as_tensor( Q[:, :, :nh], @@ -385,24 +400,44 @@ def run_test_sdpa_decode_single_iter( layout=ttnn.TILE_LAYOUT, memory_config=height_sharded_memcfg if sharded_in else dram_memcfg, ) - if cur_pos_tensor: - start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) - tt_back = ttnn.transformer.scaled_dot_product_attention_decode( - tt_Q, - tt_K, - tt_V, - cur_pos_tensor=start_indices_tt, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, - ) + if causal: + if cur_pos_tensor: + start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) + tt_back = ttnn.transformer.scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + cur_pos_tensor=start_indices_tt, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) + else: + tt_back = ttnn.transformer.scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + cur_pos=start_indices, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) else: + tt_mask = ttnn.as_tensor( + attn_mask.transpose(1, 2).contiguous(), + device=device, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + memory_config=dram_memcfg, + ) tt_back = ttnn.transformer.scaled_dot_product_attention_decode( tt_Q, tt_K, tt_V, - start_indices, + is_causal=False, + attn_mask=tt_mask, scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, @@ -425,7 +460,7 @@ def run_test_sdpa_decode_single_iter( expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) non_skip_indices = torch.tensor(start_indices) != -1 out_pass, out_pcc = comp_pcc(expect[:, 
non_skip_indices], tt_back[:, non_skip_indices], min_pcc) @@ -483,6 +518,38 @@ def test_sdpa_decode( ) +@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") +@pytest.mark.parametrize( + "dtype, q_dtype", + [ + # [ttnn.bfloat16, ttnn.bfloat16], + [ttnn.bfloat8_b, ttnn.bfloat16], + ], + ids=[ + # "all_bfp16", + "kv_bfp8", + ], +) +@pytest.mark.parametrize( + "b, nh, nkv, s, d, grid_size", + ( + [32, 32, 8, 4224, 128, (8, 8)], # llama3.2 vision encoder on n150 + [8, 16, 4, 4224, 128, (8, 8)], # llama3.2 vision encoder on n300 + [32, 4, 1, 4224, 128, (8, 8)], # llama3.2 vision encoder on n300 + ), +) +def test_sdpa_decode_non_causal(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, use_program_cache): + if nkv > 1 and q_dtype != ttnn.bfloat16: + pytest.skip("nkv > 1 requires q_dtype to be bfloat16") + + ttnn.device.DisablePersistentKernelCache() + for _ in range(2): + run_test_sdpa_decode_single_iter( + device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, sharded_in=False, sharded_out=False, causal=False + ) + assert device.num_program_cache_entries() == 1 + + @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") @pytest.mark.parametrize( "dtype, q_dtype", @@ -620,16 +687,20 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc tt_page_table = ttnn.Tensor(page_table, ttnn.int32).to(device) max_start_idx = 0 + causal = True - while max_start_idx < s: + while max_start_idx < s or not causal: scale = d**-0.5 start_indices = np.linspace(max(max_start_idx - b, 0), max_start_idx, b, dtype=np.int32).tolist() # Test when page_table does not contain blocks for full sequence length - last_block = max(1, int(math.ceil((max_start_idx + 1) / block_size))) - tt_page_table = ttnn.Tensor(page_table[:, :last_block], ttnn.int32).to(device) + if causal: + last_block = max(1, int(math.ceil((max_start_idx + 1) / block_size))) + tt_page_table = ttnn.Tensor(page_table[:, :last_block], ttnn.int32).to(device) + else: + tt_page_table = ttnn.Tensor(page_table, ttnn.int32).to(device) - k_chunk_size = get_chunk_size(max_start_idx + 1) + k_chunk_size = get_chunk_size(max_start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -637,20 +708,31 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc exp_approx_mode=False, ) - padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) + padded_layer_len = nearest_n(max_start_idx + 1, n=k_chunk_size) if causal else s # Test various sequence lengths - logger.info(f"Testing with sequence length: {max_start_idx}") + logger.debug( + f"Testing {'causal' if causal else 'non-causal'} with sequence length: {max_start_idx if causal else s}" + ) logger.info(f"Using chunk size: {k_chunk_size}") logger.info(f"Using padded layer length: {padded_layer_len}") logger.info(f"Using padded num heads: {padded_num_heads}") - attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) - for i in range(b): - start_idx = start_indices[i] - attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + if causal: + attn_mask = torch.zeros((b, padded_num_heads, 1, padded_layer_len)) + for i in range(b): + start_idx = start_indices[i] + attn_mask[i, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min + else: + attn_mask = torch.bernoulli( + torch.full( + (b, nh, 1, padded_layer_len), + 0.25, + ) + ) + attn_mask = attn_mask * 
torch.finfo(torch.float32).min - Q = fa_rand(1, b, padded_num_heads, d) + Q = fa_rand(1, b, nh, d) tt_Q = ttnn.as_tensor( Q[:, :, :nh], @@ -662,17 +744,38 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc start_indices_tt = ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device) - tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( - tt_Q, - tt_K, - tt_V, - cur_pos_tensor=start_indices_tt, - page_table_tensor=tt_page_table, - scale=scale, - program_config=program_config, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, - ) + if causal: + tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + tt_page_table, + cur_pos_tensor=start_indices_tt, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) + else: + tt_mask = ttnn.as_tensor( + attn_mask.transpose(1, 2).contiguous(), + device=device, + dtype=ttnn.bfloat4_b, + layout=ttnn.TILE_LAYOUT, + memory_config=dram_memcfg, + ) + tt_back = ttnn.transformer.paged_scaled_dot_product_attention_decode( + tt_Q, + tt_K, + tt_V, + tt_page_table, + is_causal=False, + attn_mask=tt_mask, + scale=scale, + program_config=program_config, + compute_kernel_config=compute_kernel_config, + memory_config=height_sharded_memcfg if sharded_out else dram_memcfg, + ) tt_back = ttnn.to_torch(tt_back) @@ -692,7 +795,7 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc) @@ -701,7 +804,13 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc assert out_pass max_start_idx += 71 if max_start_idx < 4096 else 3001 - # return + + if not causal: + # only run one iteration for non-causal + break + if max_start_idx >= s: + # run last iteration to test non-causal + causal = False @skip_for_blackhole("Unsupported on BH, see #12349") @@ -724,8 +833,8 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc @pytest.mark.parametrize( "b, nh, nkv, s, d, grid_size, cur_pos_tensor", ( - [32, 8, 1, 32768, 128, (8, 6), True], # Llama2-70B - [4, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b + # [32, 8, 1, 32768, 128, (8, 6), True], # Llama2-70B + # [4, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b # [4, 16, 4, 32768, 128, (8, 8), True], # [32, 32, 8, 4096, 128, (8, 8), True], # llama 3.1 8b [8, 16, 4, 4096, 128, (8, 2), True], # llama 3.1 8b N300 @@ -757,7 +866,7 @@ def test_sdpa_decode_paged_attention( sharded_out=False, ) - assert device.num_program_cache_entries() == 3 + assert device.num_program_cache_entries() == 4 @skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") @@ -985,7 +1094,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty scale = d**-0.5 - k_chunk_size = get_chunk_size(start_idx + 1) + k_chunk_size = get_chunk_size(start_idx + 1, s) program_config = ttnn.SDPAProgramConfig( compute_with_storage_grid_size=grid_size, # device.compute_with_storage_grid_size(), q_chunk_size=padded_num_heads, @@ -1013,7 +1122,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, 
dtype, grid_size, q_dty expect = torch.nn.functional.scaled_dot_product_attention( Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) + expect = expect.squeeze(2).unsqueeze(0) all_out_pass = True @@ -1030,7 +1139,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty tt_Q, tt_K, tt_V, - [start_idx for _ in range(b)], + cur_pos=[start_idx for _ in range(b)], scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index b75c86d6296b..53cc54a8afc0 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -977,7 +977,7 @@ def run_test_sdpa_decode_single_iter( tt_Q, tt_K, tt_V, - [start_idx for _ in range(b)], + cur_pos=[start_idx for _ in range(b)], scale=scale, program_config=program_config, compute_kernel_config=compute_kernel_config, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp index a5d9f2df35c4..78045f01fce6 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp @@ -375,6 +375,8 @@ void MAIN { constexpr uint32_t k_chunk_size = get_compile_time_arg_val(17); constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(18); constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(19); + constexpr bool is_causal = get_compile_time_arg_val(20) == 1; + constexpr bool use_attention_mask = get_compile_time_arg_val(21) == 1; constexpr uint32_t q_chunk_tiles = Sq_chunk_t * DHt; constexpr uint32_t k_chunk_tiles = Sk_chunk_t * DHt; @@ -423,23 +425,25 @@ void MAIN { } // Get cur_pos - uint32_t cur_pos = 0; - // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - cb_wait_front(cb_index_id, 1); - volatile uint32_t *index_addr_ptr; - cb_get_tile(cb_index_id, 0, &index_addr_ptr); - cur_pos = index_addr_ptr[4+cur_batch]; - cb_release_tile(cb_index_id); - } + uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. 
In this case we set cur_pos to the last position + if (is_causal) { + // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + cb_wait_front(cb_index_id, 1); + volatile uint32_t *index_addr_ptr; + cb_get_tile(cb_index_id, 0, &index_addr_ptr); + cur_pos = index_addr_ptr[4+cur_batch]; + cb_release_tile(cb_index_id); + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } // Sequence length assignment auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size); @@ -464,11 +468,19 @@ void MAIN { /* QK *= SCALE */ mul_block_bcast_scalar_inplace(cb_qk_im, cb_scale_in, qk_chunk_tiles); - // For decode, we only apply mask at the last chunk on reducer cor - if (k_chunk == k_chunk_end - 1 && do_reduce) { - /* QK += MASK */ - reconfig_data_format(cb_qk_im, cb_mask_in); - add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + if constexpr(is_causal){ + // For decode, we only apply mask at the last chunk on reducer core for causal mode + if (k_chunk == k_chunk_end - 1 && do_reduce) { + /* QK += MASK */ + reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } + } + else { + if constexpr(use_attention_mask){ + reconfig_data_format(cb_qk_im, cb_mask_in); + add_block_inplace(cb_qk_im, cb_mask_in, qk_chunk_tiles); + } } reconfig_data_format(cb_qk_im, cb_identity_scale_in); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp index 2e973fbcb8b3..32a8f28cf757 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/reader_decode_all.cpp @@ -50,6 +50,8 @@ void kernel_main() { constexpr uint32_t num_cores_per_head = get_compile_time_arg_val(14); constexpr uint32_t num_heads_per_core = get_compile_time_arg_val(15); constexpr uint32_t num_output_cores = get_compile_time_arg_val(16); + constexpr bool is_causal = get_compile_time_arg_val(17) == 1; + constexpr bool use_attention_mask = get_compile_time_arg_val(18) == 1; uint32_t arg_idx = 0; const uint32_t q_addr = get_arg_val(arg_idx++); @@ -57,6 +59,7 @@ void kernel_main() { const uint32_t v_addr = get_arg_val(arg_idx++); const uint32_t pos_addr = get_arg_val(arg_idx++); const uint32_t page_table_addr = get_arg_val(arg_idx++); + const uint32_t mask_addr = get_arg_val(arg_idx++); const uint32_t page_table_page_size = get_arg_val(arg_idx++); const bool is_worker = get_arg_val(arg_idx++) == 0; const bool is_output_core = get_arg_val(arg_idx++) == 1; @@ -71,32 +74,34 @@ void kernel_main() { return; } // Get cur_pos - uint32_t cur_pos = 0; - // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - const InterleavedAddrGen addrg = { - .bank_base_address = pos_addr, - .page_size = index_stick_size_B - }; - - cb_reserve_back(cb_index_id, 1); - uint32_t index_cb_wr_ptr = get_write_ptr(cb_index_id); - // 
index_tensor has one page to read - uint64_t tensor_index_noc_addr = get_noc_addr(0, addrg); - noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B); - noc_async_read_barrier(); - cb_push_back(cb_index_id, 1); - volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_wr_ptr); - cur_pos = index_ptr[cur_batch]; - } + uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. In this case we set cur_pos to the last position + if (is_causal) { + // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + const InterleavedAddrGen addrg = { + .bank_base_address = pos_addr, + .page_size = index_stick_size_B + }; + + cb_reserve_back(cb_index_id, 1); + uint32_t index_cb_wr_ptr = get_write_ptr(cb_index_id); + // index_tensor has one page to read + uint64_t tensor_index_noc_addr = get_noc_addr(0, addrg); + noc_async_read(tensor_index_noc_addr, index_cb_wr_ptr, index_stick_size_B); + noc_async_read_barrier(); + cb_push_back(cb_index_id, 1); + volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_wr_ptr); + cur_pos = index_ptr[cur_batch]; + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } const uint32_t valid_seq_len_tiles = (cur_pos + 1 + 32 - 1) / 32; @@ -137,6 +142,7 @@ void kernel_main() { constexpr uint32_t cb_q_in = tt::CB::c_in0; constexpr uint32_t cb_k_in = tt::CB::c_in1; constexpr uint32_t cb_v_in = tt::CB::c_in2; + constexpr uint32_t cb_mask_in = tt::CB::c_in3; constexpr uint32_t onetile = 1; @@ -146,6 +152,8 @@ void kernel_main() { constexpr DataFormat k_data_format = get_dataformat(cb_k_in); constexpr uint32_t v_tile_bytes = get_tile_size(cb_v_in); constexpr DataFormat v_data_format = get_dataformat(cb_v_in); + constexpr uint32_t mask_tile_bytes = get_tile_size(cb_mask_in); + constexpr DataFormat mask_data_format = get_dataformat(cb_mask_in); constexpr uint32_t barrier_threshold = get_barrier_read_threshold(); uint32_t barrier_count = 0; @@ -202,7 +210,16 @@ void kernel_main() { .data_format = v_data_format }; + const InterleavedAddrGenFast mask_reader = { + .bank_base_address = mask_addr, + .page_size = mask_tile_bytes, + .data_format = mask_data_format + }; + for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { + const uint32_t mask_batch_offset = (cur_batch % Bkv) * PNHt * St; + const uint32_t mask_chunk_offset = k_chunk_start * Sk_chunk_t; + uint32_t mask_start_tile_id = mask_batch_offset + mask_chunk_offset; if constexpr (is_paged_attention) { for (uint32_t k_chunk = k_chunk_start; k_chunk < k_chunk_end; ++k_chunk) { @@ -229,6 +246,29 @@ void kernel_main() { noc_async_read_barrier(); cb_push_back(cb_k_in, k_chunk_tiles); + if constexpr(use_attention_mask){ + // Read mask chunk + cb_reserve_back(cb_mask_in, mask_chunk_tiles); + uint32_t mask_write_ptr = get_write_ptr(cb_mask_in); + barrier_count = 0; + for (uint32_t row = 0; row < PNHt; ++row) { + uint32_t mask_tile_id = mask_start_tile_id + row * PSt; + for (uint32_t col = 0; col < Sk_chunk_t; ++col) { + noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr); + mask_tile_id++; + mask_write_ptr += mask_tile_bytes; + + if (++barrier_count == 
barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_mask_in, mask_chunk_tiles); + mask_start_tile_id += mask_chunk_tiles; + } + // Read V chunk in row major order, write in row-major order cb_reserve_back(cb_v_in, k_chunk_tiles); uint32_t v_write_ptr = get_write_ptr(cb_v_in); @@ -289,6 +329,29 @@ void kernel_main() { cb_push_back(cb_k_in, k_chunk_tiles); k_start_tile_id += k_chunk_tiles; + if constexpr(use_attention_mask){ + // Read mask chunk + cb_reserve_back(cb_mask_in, mask_chunk_tiles); + uint32_t mask_write_ptr = get_write_ptr(cb_mask_in); + barrier_count = 0; + for (uint32_t row = 0; row < PNHt; ++row) { + uint32_t mask_tile_id = mask_start_tile_id + row * PSt; + for (uint32_t col = 0; col < Sk_chunk_t; ++col) { + noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr); + mask_tile_id++; + mask_write_ptr += mask_tile_bytes; + + if (++barrier_count == barrier_threshold) { + noc_async_read_barrier(); + barrier_count = 0; + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_mask_in, mask_chunk_tiles); + mask_start_tile_id += mask_chunk_tiles; + } + // Read V chunk cb_reserve_back(cb_v_in, k_chunk_tiles); uint32_t v_write_ptr = get_write_ptr(cb_v_in); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp index 7374a7594f65..fdc083b391c5 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/dataflow/writer_decode_all.cpp @@ -244,6 +244,7 @@ void kernel_main() { constexpr uint32_t num_reducer_cores = get_compile_time_arg_val(16); constexpr uint32_t num_output_cores = get_compile_time_arg_val(17); constexpr uint32_t ELEMENT_SIZE = get_compile_time_arg_val(18); + constexpr bool is_causal = get_compile_time_arg_val(19) == 1; uint32_t arg_idx = 0; const uint32_t out_addr = get_arg_val(arg_idx++); @@ -262,22 +263,24 @@ void kernel_main() { return; } // Get cur_pos - uint32_t cur_pos = 0; + uint32_t cur_pos = St*32-1; // default to non-causal, which we do attention on the entire kv cache. 
In this case we set cur_pos to the last position // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list - if (cur_pos_arg != UINT32_MAX){ - cur_pos = cur_pos_arg; - } - else { - constexpr uint32_t cb_index_id = tt::CB::dataflow0; - cb_wait_front(cb_index_id, 1); - uint32_t index_cb_ptr = get_read_ptr(cb_index_id); - volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_ptr); - cur_pos = index_ptr[cur_batch]; - } + if (is_causal) { + if (cur_pos_arg != UINT32_MAX){ + cur_pos = cur_pos_arg; + } + else { + constexpr uint32_t cb_index_id = tt::CB::dataflow0; + cb_wait_front(cb_index_id, 1); + uint32_t index_cb_ptr = get_read_ptr(cb_index_id); + volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast(index_cb_ptr); + cur_pos = index_ptr[cur_batch]; + } - if (cur_pos == UINT32_MAX) { - // cur_pos of -1 indicates that the user should be skipped - return; + if (cur_pos == UINT32_MAX) { + // cur_pos of -1 indicates that the user should be skipped + return; + } } // Sequence length assignment auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num_in_reduce, num_cores_per_head, k_chunk_size); @@ -347,8 +350,8 @@ void kernel_main() { constexpr uint32_t barrier_threshold = get_barrier_read_threshold(); uint32_t barrier_count = 0; - // generate and send mask to compute - generate_mask(k_num_chunks, PSt, cur_pos); + // generate and send mask to compute if causal + if constexpr(is_causal) generate_mask(k_num_chunks, PSt, cur_pos); for (uint32_t cur_head = cur_head_group*num_heads_per_core; cur_head < cur_head_group*num_heads_per_core + num_heads_per_core; ++cur_head) { if (k_chunk_end - k_chunk_start < k_num_chunks){ diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp index 476e0bbd57e7..8baed071307b 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.cpp @@ -44,30 +44,60 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); } + if (!this->is_causal) { + if (optional_input_tensors.at(2).has_value()){ + // Non-causal attention verification + const auto& mask_tensor = optional_input_tensors.at(2).value(); + const auto mask_shape = mask_tensor.get_legacy_shape(); + const auto mask_shape_unpadded = mask_tensor.get_shape(); + + TT_FATAL(mask_shape[2] == q_shape[2], "Expect same number of padded heads in mask as in Q, got {} and {}", mask_shape[2], q_shape[2]); + TT_FATAL(mask_shape_unpadded[2] == q_shape_unpadded[2], "Expect same number of heads in mask as in Q, got {} and {}", mask_shape_unpadded[2], q_shape_unpadded[2]); + if (!
this->paged_attention) TT_FATAL(mask_shape[3] == k_shape[2], "Expect same sequence length in mask as in K, got {} and {}", mask_shape[3], k_shape[2]); + TT_FATAL(mask_shape[3] % k_chunk_size == 0, "Mask sequence length must be multiple of chunk size, got: {} and {}", mask_shape[3], k_chunk_size); + + TT_FATAL( + mask_tensor.get_dtype() == DataType::BFLOAT16 || mask_tensor.get_dtype() == DataType::BFLOAT8_B || + mask_tensor.get_dtype() == DataType::BFLOAT4_B, + "Unsupported data type for mask tensor: {}.", + mask_tensor.get_dtype()); + } + } else { + // Causal attention verification + TT_FATAL(not optional_input_tensors.at(2).has_value(), "Must not have attn_mask tensor for causal attention"); + } + if (this->paged_attention) { // Paged attention verification TT_FATAL(! this->share_cache.value_or(false), "Share cache feature not supported for paged attention"); - TT_FATAL(optional_input_tensors.at(0).has_value(), "Must have cur_pos tensor for paged attention"); - TT_FATAL(optional_input_tensors.at(1).has_value(), "Must have page_table tensor for paged attention"); + const auto B = q_shape[1]; - const auto& cur_pos_tensor = optional_input_tensors.at(0).value(); - const auto& page_table_tensor = optional_input_tensors.at(1).value(); + if (this->is_causal) { + // Check cur pos tensor for causal mode + TT_FATAL(optional_input_tensors.at(0).has_value(), "Must have cur_pos tensor for paged attention in causal mode"); + const auto& cur_pos_tensor = optional_input_tensors.at(0).value(); + TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Expect cur_pos to be INT32, got {}", cur_pos_tensor.get_dtype()); + TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Expect cur_pos to be ROW_MAJOR, got {}", cur_pos_tensor.get_layout()); + const auto cur_pos_shape = cur_pos_tensor.get_legacy_shape(); + TT_FATAL(cur_pos_shape[0] == B, "cur_pos must have batch size equal to Q, got {} and {}", cur_pos_shape[0], B); + } - TT_FATAL(cur_pos_tensor.get_dtype() == DataType::INT32, "Error"); - TT_FATAL(cur_pos_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); + TT_FATAL(optional_input_tensors.at(1).has_value(), "Must have page_table tensor for paged attention"); + const auto& page_table_tensor = optional_input_tensors.at(1).value(); TT_FATAL(page_table_tensor.get_dtype() == DataType::INT32, "Error"); TT_FATAL(page_table_tensor.get_layout() == Layout::ROW_MAJOR, "Error"); - const auto cur_pos_shape = cur_pos_tensor.get_legacy_shape(); const auto page_table_shape = page_table_tensor.get_legacy_shape(); - const auto B = q_shape[1]; - TT_FATAL(cur_pos_shape[0] == B, "cur_pos must have batch size equal to Q"); TT_FATAL(page_table_shape[0] == B, "page_table must have hidden size equal to Q"); TT_FATAL(k_shape[2] == v_shape[2], "K and V must have same block size"); TT_FATAL(k_shape[3] == v_shape[3] && k_shape[3] == q_shape[3], "Q, K, V must have same hidden size"); + + // Validate chunk size for paged version + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, got: {}", k_chunk_size); + if (!
this->is_causal) TT_FATAL((page_table_shape[1]*k_shape[2]) % k_chunk_size == 0, "K sequence length must be multiple of chunk size, got: {} and {}", page_table_shape[1]*k_shape[2], k_chunk_size); } else { // Unpaged attention verification TT_FATAL(not optional_input_tensors.at(1).has_value(), "Must not have page_table tensor for unpaged attention"); @@ -93,6 +123,10 @@ void ScaledDotProductAttentionDecode::validate(const std::vector& input_ // Check sequence lengths TT_FATAL(k_shape[-2] == v_shape[-2], "Error"); + // Validate chunk size for unpaged version + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, got: {}", k_chunk_size); + TT_FATAL(k_shape[2] % k_chunk_size == 0, "K sequence length must be multiple of chunk size, got: {} and {}", k_shape[2], k_chunk_size); + // Check hidden size const auto D = q_shape[-1]; TT_FATAL(k_shape[-1] == D, "Error"); @@ -134,6 +168,7 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( auto& cur_pos_tensor = optional_input_tensors.at(0); auto& page_table_tensor = optional_input_tensors.at(1); + auto& attn_mask = optional_input_tensors.at(2); auto& output_tensor = output_tensors.at(0); @@ -148,7 +183,9 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( input_tensor_v, cur_pos_tensor, page_table_tensor, + attn_mask, output_tensor, + this->is_causal, this->cur_pos, scale, this->compute_kernel_config, @@ -158,6 +195,8 @@ operation::ProgramWithCallbacks ScaledDotProductAttentionDecode::create_program( } operation::Hash ScaledDotProductAttentionDecode::compute_program_hash(const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { + bool has_cur_pos = optional_input_tensors.at(0).has_value(); + bool has_attn_mask = optional_input_tensors.at(2).has_value(); return operation::hash_operation( this->scale, this->output_mem_config, @@ -165,6 +204,9 @@ operation::Hash ScaledDotProductAttentionDecode::compute_program_hash(const std: this->compute_kernel_config, this->k_chunk_size, this->paged_attention, + this->is_causal, + has_attn_mask, + has_cur_pos, input_tensors); } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp index 626f10f133df..7993a1a96b2d 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_op.hpp @@ -14,6 +14,7 @@ namespace ttnn::operations::transformer { struct ScaledDotProductAttentionDecode { + const bool is_causal; std::vector cur_pos; const std::optional scale; const MemoryConfig output_mem_config; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp index 2c0a99f0389b..92d88c002f40 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.cpp @@ -25,7 +25,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const Tensor& input_tensor_v, std::optional cur_pos_tensor, std::optional page_table_tensor, + std::optional attn_mask, const Tensor& output_tensor, + bool is_causal, const std::vector& cur_pos_ids, std::optional scale, DeviceComputeKernelConfig compute_kernel_config, @@ -59,6 +61,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( 
if (is_paged_attention) { uint32_t block_size = k_shape[2]; page_block_size_t = block_size / TILE_HEIGHT; + // get real S using the page_table_tensor + S = page_table_tensor.value().get_legacy_shape()[-1]*S; } uint32_t Bkv = k_shape[0]; uint32_t St = S/TILE_HEIGHT; @@ -102,6 +106,10 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( auto out0_buffer = output_tensor.buffer(); bool use_cur_pos_tensor = cur_pos_tensor.has_value(); + bool use_attention_mask = attn_mask.has_value(); + + log_debug("use_cur_pos_tensor: {}", use_cur_pos_tensor); + log_debug("use_attention_mask: {}", use_attention_mask); // Parallelization scheme // We will assign cores to batches @@ -262,6 +270,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( tt::DataFormat q_df = tt_metal::datatype_to_dataformat_converter(input_tensor_q.get_dtype()); tt::DataFormat k_df = tt_metal::datatype_to_dataformat_converter(input_tensor_k.get_dtype()); tt::DataFormat v_df = tt_metal::datatype_to_dataformat_converter(input_tensor_v.get_dtype()); + tt::DataFormat mask_df = use_attention_mask ? tt_metal::datatype_to_dataformat_converter(attn_mask.value().get_dtype()) : tt::DataFormat::Float16_b; tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); tt::DataFormat scalar_df = tt::DataFormat::Float16_b; tt::DataFormat im_df = tt::DataFormat::Float16_b; @@ -271,6 +280,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t q_tile_size = tt_metal::detail::TileSize(q_df); uint32_t k_tile_size = tt_metal::detail::TileSize(k_df); uint32_t v_tile_size = tt_metal::detail::TileSize(v_df); + uint32_t mask_tile_size = tt_metal::detail::TileSize(mask_df); uint32_t out_tile_size = tt_metal::detail::TileSize(out_df); uint32_t scalar_tile_size = tt_metal::detail::TileSize(scalar_df); uint32_t im_tile_size = tt_metal::detail::TileSize(im_df); @@ -329,7 +339,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( auto cb_in2_id = CreateCircularBuffer(program, core_grid, c_in2_config); // attn_mask input - auto c_in3_config = CircularBufferConfig(qk_tiles * stats_tile_size, {{CB::c_in3, stats_df}}).set_page_size(CB::c_in3, stats_tile_size); + auto c_in3_config = CircularBufferConfig(qk_tiles * mask_tile_size, {{CB::c_in3, mask_df}}).set_page_size(CB::c_in3, mask_tile_size); auto cb_in3_id = CreateCircularBuffer(program, core_grid, c_in3_config); // scale input @@ -486,7 +496,8 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( B, PNHt, St, DHt, Sk_chunk_t, num_active_cores, is_q_sharded, num_cores_per_batch, k_chunk_size, index_stick_size, (uint32_t)is_paged_attention, num_kv_heads, page_block_size_t, - Bkv, num_cores_per_head, num_heads_per_core, num_output_cores + Bkv, num_cores_per_head, num_heads_per_core, num_output_cores, + is_causal, use_attention_mask, }; std::vector writer_compile_time_args_common = { @@ -505,14 +516,15 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( num_heads_per_core, num_reducer_cores, num_output_cores, - output_tensor.element_size() + output_tensor.element_size(), + is_causal, }; std::vector compute_compile_time_args_common = { St, DHt, PNHt, Sk_chunk_t, qk_in0_block_w, qk_out_subblock_w, qk_out_subblock_h, qk_in0_num_subblocks, qk_in1_num_subblocks, qk_num_blocks, out_in0_block_w, out_out_subblock_w, out_out_subblock_h, out_in0_num_subblocks, out_in1_num_subblocks, out_num_blocks, - num_cores_per_batch, k_chunk_size, num_cores_per_head, num_heads_per_core + num_cores_per_batch, k_chunk_size, num_cores_per_head, num_heads_per_core, 
is_causal, use_attention_mask, }; std::map defines; @@ -562,6 +574,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t v_addr = v_buffer->address(); uint32_t pos_addr = use_cur_pos_tensor ? cur_pos_tensor.value().buffer()->address() : 0; uint32_t page_table_addr = is_paged_attention ? page_table_tensor.value().buffer()->address() : 0; + uint32_t attn_mask_addr = use_attention_mask ? attn_mask.value().buffer()->address() : 0; uint32_t out_addr = out0_buffer->address(); // Set rt args @@ -577,7 +590,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t core_num_in_reduce = i % num_cores_per_head; uint32_t core_num_in_output = i % num_cores_per_batch; - uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch); + uint32_t cur_pos = (use_cur_pos_tensor || ! is_causal) ? -1 : cur_pos_ids.at(cur_batch); log_debug("---- core_id: {}, coord: {} ----", i, core); log_debug("worker_id_for_reduce: {}", worker_id_for_reduce); @@ -591,7 +604,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( log_debug("cur_pos: {}", cur_pos); // reader runtime args - std::vector reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, page_table_stick_size, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos}; + std::vector reader_rt_args = { q_addr, k_addr, v_addr, pos_addr, page_table_addr, attn_mask_addr, page_table_stick_size, do_reduce, do_output, cur_head, cur_batch, core_num_in_reduce, core_num_in_output, cur_pos}; reader_rt_args.insert(reader_rt_args.end(), output_core_physical_xs.begin(), output_core_physical_xs.end()); reader_rt_args.insert(reader_rt_args.end(), output_core_physical_ys.begin(), output_core_physical_ys.end()); @@ -640,7 +653,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( cb_out4_id, B, use_cur_pos_tensor, - is_paged_attention + use_attention_mask, + is_paged_attention, + is_causal ] ( const void* operation, @@ -662,6 +677,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t v_addr = v_buffer->address(); uint32_t pos_addr = use_cur_pos_tensor ? optional_input_tensors.at(0).value().buffer()->address() : 0; uint32_t page_table_addr = is_paged_attention ? optional_input_tensors.at(1).value().buffer()->address() : 0; + uint32_t attn_mask_addr = use_attention_mask ? optional_input_tensors.at(2).value().buffer()->address() : 0; auto page_table_buffer = is_paged_attention ? optional_input_tensors.at(1).value().buffer() : nullptr; uint32_t page_table_stick_size = is_paged_attention ? page_table_buffer->aligned_page_size() : 0; uint32_t out_addr = out0_buffer->address(); @@ -681,7 +697,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( uint32_t cur_batch = i / num_cores_per_batch; uint32_t core_num_in_reduce = (num_cores_per_head == 0) ? 0 : i % num_cores_per_head; uint32_t core_num_in_output = i % num_cores_per_batch; - uint32_t cur_pos = use_cur_pos_tensor ? -1 : cur_pos_ids.at(cur_batch); + uint32_t cur_pos = (use_cur_pos_tensor || ! is_causal) ? 
-1 : cur_pos_ids.at(cur_batch); auto& reader_args = reader_args_by_core[core.x][core.y]; auto& writer_args = writer_args_by_core[core.x][core.y]; @@ -694,6 +710,7 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( reader_args[arg_idx++] = v_addr; reader_args[arg_idx++] = pos_addr; reader_args[arg_idx++] = page_table_addr; + reader_args[arg_idx++] = attn_mask_addr; reader_args[arg_idx++] = page_table_stick_size; reader_args[arg_idx++] = do_reduce; reader_args[arg_idx++] = do_output; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp index 9d619f9052d9..ea25388791ed 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/sdpa_decode_program_factory.hpp @@ -16,7 +16,9 @@ operation::ProgramWithCallbacks sdpa_decode_multi_core( const Tensor &input_tensor_v, std::optional cur_pos_tensor, std::optional page_table_tensor, + std::optional attn_mask, const Tensor &output_tensor, + bool is_causal, const std::vector& cur_pos_ids, std::optional scale, DeviceComputeKernelConfig compute_kernel_config, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp index cf5804d559a3..6035b3876815 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.cpp @@ -10,13 +10,19 @@ namespace { uint32_t get_chunk_size(uint32_t s) { - if (s <= 128) { - return 32; + /* + # find maximum power of 2 divisor of s + for i in range(1, s): + if s % (2**(i+1)) != 0: + break + */ + uint32_t i = 1; + for (; i < s; i++) { + if (s % (1 << (i + 1)) != 0) { + break; + } } - if (s <= 256) { - return 256; - } - return 512; + return std::min(512, 1 << i); } } // namespace @@ -27,6 +33,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -35,13 +43,15 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( std::optional compute_kernel_config) { auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? 
input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - //uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end()); - uint32_t k_chunk_size = 512; //get_chunk_size(max_cur_pos + 1); + uint32_t s = input_tensor_k.get_shape()[-2]; + uint32_t k_chunk_size = get_chunk_size(s); if (program_config.has_value() && program_config.value().k_chunk_size > 0) { k_chunk_size = program_config.value().k_chunk_size; // assert chunk size must be power of 2 and multiple of 32 TT_FATAL((k_chunk_size & (k_chunk_size - 1)) == 0, "User provided k_chunk_size must be power of 2, got: {}", k_chunk_size); TT_FATAL(k_chunk_size % 32 == 0, "User provided k_chunk_size must be multiple of 32, got: {}", k_chunk_size); + } else { + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, but the maximum calculated k_chunk_size is: {}", k_chunk_size); } // get chunk size and then pass to sdpa decode as an attribute for prgm cache @@ -50,6 +60,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( return operation::run( ScaledDotProductAttentionDecode{ + .is_causal = is_causal, .cur_pos = cur_pos, .scale = scale, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), @@ -58,7 +69,7 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( .k_chunk_size = k_chunk_size, .paged_attention = false}, {input_tensor_q, input_tensor_k, input_tensor_v}, - {cur_pos_tensor, std::nullopt}, + {cur_pos_tensor, std::nullopt, attn_mask}, {}, queue_id) .at(0); @@ -68,6 +79,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -79,6 +92,8 @@ ttnn::Tensor ExecuteScaledDotProductAttentionDecode::invoke( input_tensor_q, input_tensor_k, input_tensor_v, + is_causal, + attn_mask, cur_pos, cur_pos_tensor, scale, @@ -93,21 +108,25 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, std::optional compute_kernel_config) { auto arch = input_tensor_q.storage_type() == StorageType::DEVICE ? 
input_tensor_q.device()->arch() : ttnn::operations::experimental::auto_format::AutoFormat::GetDefaultDevice()->arch(); - //uint32_t max_cur_pos = *std::max_element(cur_pos.begin(), cur_pos.end()); - uint32_t k_chunk_size = 512; //get_chunk_size(max_cur_pos + 1); + uint32_t s = input_tensor_k.get_shape()[-2]; + uint32_t k_chunk_size = get_chunk_size(s); if (program_config.has_value() && program_config.value().k_chunk_size > 0) { k_chunk_size = program_config.value().k_chunk_size; // assert chunk size must be power of 2 and multiple of 32 TT_FATAL((k_chunk_size & (k_chunk_size - 1)) == 0, "User provided k_chunk_size must be power of 2, got: {}", k_chunk_size); TT_FATAL(k_chunk_size % 32 == 0, "User provided k_chunk_size must be multiple of 32, got: {}", k_chunk_size); + } else { + TT_FATAL(k_chunk_size % 32 == 0, "Chunk size must be multiple of 32, but the maximum calculated k_chunk_size is: {}", k_chunk_size); } // get chunk size and then pass to sdpa decode as an attribute for prgm cache @@ -116,6 +135,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( return operation::run( ScaledDotProductAttentionDecode{ + .is_causal = is_causal, .cur_pos = std::vector(), .scale = scale, .output_mem_config = memory_config.value_or(operation::DEFAULT_OUTPUT_MEMORY_CONFIG), @@ -124,7 +144,7 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( .k_chunk_size = k_chunk_size, .paged_attention = true}, {input_tensor_q, input_tensor_k, input_tensor_v}, - {cur_pos_tensor, page_table_tensor}, + {cur_pos_tensor, page_table_tensor, attn_mask}, {}, queue_id) .at(0); @@ -134,8 +154,10 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, @@ -145,8 +167,10 @@ ttnn::Tensor ExecutePagedScaledDotProductAttentionDecode::invoke( input_tensor_q, input_tensor_k, input_tensor_v, - cur_pos_tensor, page_table_tensor, + is_causal, + attn_mask, + cur_pos_tensor, scale, memory_config, program_config, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp index b86a7696288a..7b8eee9ce166 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode.hpp @@ -17,7 +17,9 @@ struct ExecuteScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::vector cur_pos= std::vector(), const std::optional cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, @@ -28,7 +30,9 @@ struct ExecuteScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const std::vector cur_pos, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::vector cur_pos= std::vector(), const std::optional cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = 
std::nullopt, @@ -42,8 +46,10 @@ struct ExecutePagedScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::optional &cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, std::optional program_config = std::nullopt, @@ -53,8 +59,10 @@ struct ExecutePagedScaledDotProductAttentionDecode { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal= true, + const std::optional attn_mask= std::nullopt, + const std::optional &cur_pos_tensor= std::nullopt, std::optional scale = std::nullopt, const std::optional &memory_config = std::nullopt, std::optional program_config = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp index 1215563ff4a2..8026d56385b3 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/sdpa_decode_pybind.cpp @@ -26,11 +26,12 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q (ttnn.Tensor): the input tensor [1 x b x nh x dh] input_tensor_k (ttnn.Tensor): the input tensor [b x nkv x s x dh] input_tensor_v (ttnn.Tensor): the input tensor [b x nkv x s x dh] - cur_pos (List of int): list of integers of length b. - Keyword args: + is_causal (bool): whether the attention is causal. Defaults to `True`. + attn_mask (ttnn.Tensor, optional): the input tensor [b x 1 x s x s]. Defaults to `None`. + cur_pos (List of int, optional): list of integers of length b. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. cur_pos_tensor (ttnn.Tensor, optional): [b] tensor of integers of length b. Defaults to `None`.
@@ -57,6 +58,8 @@ void py_bind_sdpa_decode(py::module &module) { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, + const bool is_causal, + const std::optional attn_mask, const std::vector cur_pos, const std::optional cur_pos_tensor, std::optional scale, @@ -69,6 +72,8 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q, input_tensor_k, input_tensor_v, + is_causal, + attn_mask, cur_pos, cur_pos_tensor, scale, @@ -79,8 +84,10 @@ void py_bind_sdpa_decode(py::module &module) { py::arg("input_tensor_q").noconvert(), py::arg("input_tensor_k").noconvert(), py::arg("input_tensor_v").noconvert(), - py::arg("cur_pos").noconvert() = std::vector(), py::kw_only(), + py::arg("is_causal").noconvert() = true, + py::arg("attn_mask").noconvert() = std::nullopt, + py::arg("cur_pos").noconvert() = std::vector(), py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, @@ -100,8 +107,10 @@ void py_bind_sdpa_decode(py::module &module) { const ttnn::Tensor &input_tensor_q, const ttnn::Tensor &input_tensor_k, const ttnn::Tensor &input_tensor_v, - const ttnn::Tensor &cur_pos_tensor, const ttnn::Tensor &page_table_tensor, + const bool is_causal, + const std::optional attn_mask, + const std::optional &cur_pos_tensor, std::optional scale, const std::optional &memory_config, std::optional program_config, @@ -112,8 +121,10 @@ void py_bind_sdpa_decode(py::module &module) { input_tensor_q, input_tensor_k, input_tensor_v, - cur_pos_tensor, page_table_tensor, + is_causal, + attn_mask, + cur_pos_tensor, scale, memory_config, program_config, @@ -122,9 +133,11 @@ void py_bind_sdpa_decode(py::module &module) { py::arg("input_tensor_q").noconvert(), py::arg("input_tensor_k").noconvert(), py::arg("input_tensor_v").noconvert(), - py::arg("cur_pos_tensor").noconvert(), py::arg("page_table_tensor").noconvert(), py::kw_only(), + py::arg("is_causal").noconvert() = true, + py::arg("attn_mask").noconvert() = std::nullopt, + py::arg("cur_pos_tensor").noconvert() = std::nullopt, py::arg("scale").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt, py::arg("program_config").noconvert() = std::nullopt,
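Reviewer note: a minimal usage sketch of the Python API after this change, mirroring the calls exercised in test_scaled_dot_product_attention_decode.py above. The tensors (tt_Q, tt_K, tt_V, attn_mask, tt_page_table), the device handle, and the program/compute-kernel/memory configs are assumed to already exist as in that test; shapes and dtypes in the comments are illustrative, not mandated by this diff.

# Causal decode (default). cur_pos is now keyword-only; either a per-user position list or
# an int32 cur_pos_tensor selects how far into the KV cache each user attends.
tt_out = ttnn.transformer.scaled_dot_product_attention_decode(
    tt_Q,
    tt_K,
    tt_V,
    cur_pos=start_indices,  # or: cur_pos_tensor=ttnn.Tensor(torch.tensor(start_indices), ttnn.int32).to(device)
    scale=scale,
    program_config=program_config,
    compute_kernel_config=compute_kernel_config,
)

# Non-causal decode (new path). No cur_pos is given; the kernels default cur_pos to the last
# position and attend over the full KV length, with masking supplied explicitly via attn_mask.
tt_mask = ttnn.as_tensor(
    attn_mask.transpose(1, 2).contiguous(),  # mask laid out as [b, 1, nh, padded_layer_len] in the test
    device=device,
    dtype=ttnn.bfloat4_b,
    layout=ttnn.TILE_LAYOUT,
    memory_config=dram_memcfg,
)
tt_out = ttnn.transformer.scaled_dot_product_attention_decode(
    tt_Q,
    tt_K,
    tt_V,
    is_causal=False,
    attn_mask=tt_mask,
    scale=scale,
    program_config=program_config,
    compute_kernel_config=compute_kernel_config,
)

# The paged variant follows the same pattern, with page_table_tensor as the sole positional
# cache argument and cur_pos_tensor moved to a keyword:
#   ttnn.transformer.paged_scaled_dot_product_attention_decode(
#       tt_Q, tt_K, tt_V, tt_page_table, cur_pos_tensor=start_indices_tt, ...)
# Chunk sizing note: get_chunk_size is now capped by the largest power-of-2 divisor of the KV
# length, e.g. s = 4224 = 2**7 * 33 yields a k_chunk_size of at most 128.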