diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index aba44a9ef65..e968b002f69 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -352,7 +352,6 @@ def fwd_prefill( query_layer, key_layer, value_layer, - attention_mask, is_causal=True, scale=self.scalar, program_config=self.model_config["SDPA_PROGCFG"], diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py index cba65a76bdf..dca04a6d2e4 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py @@ -405,7 +405,6 @@ def prefill_attn_mqa( query_layer, key_layer, value_layer, - attn_masks, is_causal=True, scale=self.scale, program_config=self.model_config["SDPA_PROGCFG"], diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 446a97881f2..5efc7158d42 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -368,7 +368,6 @@ def forward_prefill(self, xs_11SH, attn_masks, rot_mats, transformation_mats, us q_heads_14SD, k_heads_11SD, v_heads_11SD, - attn_masks, is_causal=True, scale=self.scale, program_config=self.model_config["SDPA_PROGCFG"](seq_len), diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index ec3a6d8f999..0888550e944 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -635,7 +635,6 @@ def prefill_attn_mqa( query_layer, key_layer, value_layer, - attn_masks, is_causal=True, scale=self.scale, ) diff --git a/models/demos/wormhole/llama31_8b/tt/llama_attention.py b/models/demos/wormhole/llama31_8b/tt/llama_attention.py index cf6dc2c012f..38f32247472 100644 --- a/models/demos/wormhole/llama31_8b/tt/llama_attention.py +++ b/models/demos/wormhole/llama31_8b/tt/llama_attention.py @@ -427,7 +427,6 @@ def forward_prefill(self, xs_11SH, attn_masks, rot_mats, transformation_mats, us q_heads_84SD, k_heads_K1SD, v_heads_V1SD, - attn_masks, is_causal=True, scale=self.scale, program_config=self.model_config["SDPA_PROGCFG"](seq_len), diff --git a/models/demos/wormhole/mistral7b/tt/mistral_attention.py b/models/demos/wormhole/mistral7b/tt/mistral_attention.py index 49d060e4b84..7fc4b72aa12 100644 --- a/models/demos/wormhole/mistral7b/tt/mistral_attention.py +++ b/models/demos/wormhole/mistral7b/tt/mistral_attention.py @@ -608,7 +608,6 @@ def forward_prefill(self, xs_11SH, attn_masks, rot_mats, transformation_mats, us q_heads_84SD, k_heads_K1SD, v_heads_V1SD, - attn_masks, is_causal=True, scale=self.scale, program_config=self.model_config["SDPA_PROGCFG"](seq_len), diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py index e88fbb5fb68..be5b730bf6f 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_scaled_dot_product_attention.py @@ -31,32 +31,28 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype compute_kernel_config = ttnn.WormholeComputeKernelConfig( 
math_fidelity=ttnn.MathFidelity.HiFi2, math_approx_mode=True, - fp32_dest_acc_en=True, + fp32_dest_acc_en=False, packer_l1_acc=False, ) Q = torch.randn(b, nh, s, d) K = torch.randn(b, nkv, s, d) V = torch.randn(b, nkv, s, d) - attn_mask = torch.full((s, s), torch.finfo(torch.float32).min) - attn_mask = torch.triu(attn_mask, diagonal=1).expand(b, 1, -1, -1) # Print shapes of all inputs along with input names logger.debug(f"Q: {Q.shape}") logger.debug(f"K: {K.shape}") logger.debug(f"V: {V.shape}") - logger.debug(f"attn_mask: {attn_mask.shape}") tt_Q = ttnn.Tensor(Q, dtype).to(ttnn.TILE_LAYOUT).to(device) tt_K = ttnn.Tensor(K, dtype).to(ttnn.TILE_LAYOUT).to(device) tt_V = ttnn.Tensor(V, dtype).to(ttnn.TILE_LAYOUT).to(device) - tt_attn_mask = ttnn.Tensor(attn_mask, dtype).to(ttnn.TILE_LAYOUT).to(device) tt_back = ttnn.transformer.scaled_dot_product_attention( - tt_Q, tt_K, tt_V, tt_attn_mask, is_causal=True, program_config=program_config + tt_Q, tt_K, tt_V, is_causal=True, program_config=program_config, compute_kernel_config=compute_kernel_config ) tt_back = tt_back.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - gt = torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask, is_causal=False) + gt = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True) out_pass, out_pcc = comp_pcc(gt, tt_back, 0.994) logger.debug(f"python vs pytorch: {out_pcc}") @@ -112,280 +108,3 @@ def test_sdpa_tt_with_program_cache(device, b, nh, nkv, s, d, q_chunk_size, k_ch run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype) assert device.num_program_cache_entries() == 1 - - -def nearest_n(x, n): - return ((x + n - 1) // n) * n - - -def nearest_pow_2(x): - if x < 1: - raise ValueError("x must be >= 1") - import math - - power = math.ceil(math.log2(x)) - return 1 << power - - -def num_to_corerange(x): - assert x < 8 or x % 8 == 0 - num_x = min(x, 8) - num_y = x // num_x - assert num_x * num_y == x - return ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(num_x - 1, num_y - 1), - ) - - -def get_chunk_size(s): - # Not sure if optimal - if s <= 32: - return 32 - if s <= 64: - return 64 - if s <= 128: - return 128 - if s <= 256: - return 256 - if s <= 2048: - return 512 - return 1024 - - -def run_test_sdpa_decode(device, b, nh, nkv, s, d, dtype): - padded_num_heads = nearest_pow_2(nearest_n(nh, n=32)) - torch.manual_seed(1234) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=False, - ) - dram_memcfg = ttnn.DRAM_MEMORY_CONFIG - shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)}) - shard_spec = ttnn.ShardSpec(shard_grid, (padded_num_heads, d), ttnn.ShardOrientation.ROW_MAJOR, False) - - height_sharded_memcfg = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) - - K = torch.randn(nkv, b, s, d) - V = torch.randn(nkv, b, s, d) - - tt_K = ttnn.as_tensor(K, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg) - tt_V = ttnn.as_tensor(V, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg) - - start_idx = 31 - - while start_idx < s: - scale = d**-0.5 - - k_chunk_size = get_chunk_size(start_idx) - program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=device.compute_with_storage_grid_size(), - q_chunk_size=padded_num_heads, - k_chunk_size=k_chunk_size, - ) - - padded_layer_len = nearest_n(start_idx, n=k_chunk_size) - - # Test various sequence 
lengths - logger.debug(f"Testing with sequence length: {start_idx}") - logger.debug(f"Using chunk size: {k_chunk_size}") - logger.debug(f"Using padded layer length: {padded_layer_len}") - logger.debug(f"Using padded num heads: {padded_num_heads}") - - attn_mask = torch.zeros((1, b, padded_num_heads, padded_layer_len)) - # Assume all users are at same position - attn_mask[:, :, :, start_idx:] = torch.finfo(torch.float32).min - - Q = torch.randn(1, b, padded_num_heads, d) - - tt_Q = ttnn.as_tensor( - Q, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=height_sharded_memcfg - ) - - tt_attn_mask = ttnn.as_tensor( - attn_mask, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg - ) - - tt_back = ttnn.transformer.scaled_dot_product_attention( - tt_Q, - tt_K, - tt_V, - tt_attn_mask, - is_causal=False, - scale=scale, - program_config=program_config, - valid_seq_len=padded_layer_len, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg, - ) - - tt_back = ttnn.to_torch(tt_back) - tt_back = tt_back[:, :, :nh, :] - - Q_slice = Q[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, d - K_slice = K[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d - V_slice = V[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d - attn_mask_slice = attn_mask[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, S - expect = torch.nn.functional.scaled_dot_product_attention( - Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False - ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) - - out_pass, out_pcc = comp_pcc(expect, tt_back, 0.99) - - logger.debug(f"python vs pytorch: {out_pcc}") - assert out_pass - - start_idx += 601 - - -@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") -@pytest.mark.parametrize( - "dtype", - [ttnn.bfloat8_b, ttnn.bfloat16], - ids=["bfp8", "bf16"], -) -@pytest.mark.parametrize( - "b, nh, nkv, s, d", - ( - [16, 8, 1, 8192, 128], # Llama2-70B - [32, 16, 1, 2048, 64], # Falcon-40B - [32, 4, 1, 8192, 128], # Mixtral - ), -) -def test_sdpa_decode(device, b, nh, nkv, s, d, dtype): - ttnn.device.DisablePersistentKernelCache() - run_test_sdpa_decode(device, b, nh, nkv, s, d, dtype) - - -def run_test_sdpa_decode_single_iter(device, b, nh, nkv, s, d, dtype): - padded_num_heads = nearest_pow_2(nearest_n(nh, n=32)) - torch.manual_seed(1234) - - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=ttnn.MathFidelity.HiFi4, - math_approx_mode=False, - fp32_dest_acc_en=True, - packer_l1_acc=False, - ) - dram_memcfg = ttnn.DRAM_MEMORY_CONFIG - shard_grid = ttnn.CoreRangeSet({num_to_corerange(b)}) - shard_spec = ttnn.ShardSpec(shard_grid, (padded_num_heads, d), ttnn.ShardOrientation.ROW_MAJOR, False) - - height_sharded_memcfg = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, shard_spec) - - K = torch.randn(nkv, b, s, d) - V = torch.randn(nkv, b, s, d) - - tt_K = ttnn.as_tensor(K, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg) - tt_V = ttnn.as_tensor(V, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg) - - start_idx = s // 2 - scale = d**-0.5 - - k_chunk_size = get_chunk_size(start_idx) - program_config = ttnn.SDPAProgramConfig( - compute_with_storage_grid_size=device.compute_with_storage_grid_size(), - q_chunk_size=padded_num_heads, - k_chunk_size=k_chunk_size, - ) - - padded_layer_len = nearest_n(start_idx, n=k_chunk_size) - - # Test various sequence lengths - 
logger.debug(f"Testing with sequence length: {start_idx}") - logger.debug(f"Using chunk size: {k_chunk_size}") - logger.debug(f"Using padded layer length: {padded_layer_len}") - logger.debug(f"Using padded num heads: {padded_num_heads}") - - attn_mask = torch.zeros((1, b, padded_num_heads, padded_layer_len)) - # Assume all users are at same position - attn_mask[:, :, :, start_idx:] = torch.finfo(torch.float32).min - - Q = torch.randn(1, b, padded_num_heads, d) - - tt_Q = ttnn.as_tensor(Q, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=height_sharded_memcfg) - - tt_attn_mask = ttnn.as_tensor( - attn_mask, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg - ) - - tt_back = ttnn.transformer.scaled_dot_product_attention( - tt_Q, - tt_K, - tt_V, - tt_attn_mask, - is_causal=False, - scale=scale, - program_config=program_config, - valid_seq_len=padded_layer_len, - compute_kernel_config=compute_kernel_config, - memory_config=height_sharded_memcfg, - ) - - tt_back = ttnn.to_torch(tt_back) - tt_back = tt_back[:, :, :nh, :] - - Q_slice = Q[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, d - K_slice = K[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d - V_slice = V[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d - attn_mask_slice = attn_mask[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, S - expect = torch.nn.functional.scaled_dot_product_attention( - Q_slice, K_slice, V_slice, attn_mask_slice, scale=scale, is_causal=False - ) # b, nh, 1, d - expect = expect.squeeze().unsqueeze(0) - - out_pass, out_pcc = comp_pcc(expect, tt_back, 0.99) - - logger.debug(f"python vs pytorch: {out_pcc}") - assert out_pass - - -@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs") -@pytest.mark.parametrize( - "dtype", - [ttnn.bfloat8_b, ttnn.bfloat16], - ids=["bfp8", "bf16"], -) -@pytest.mark.parametrize( - "b, nh, nkv, s, d", - ([16, 8, 1, 8192, 128],), # Llama2-70B -) -def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_cache): - ttnn.device.DisablePersistentKernelCache() - - dummy_tensors = [] - for _ in range(2): - dummy_tensors.append( - ttnn.as_tensor( - torch.zeros(32, 32), - device=device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - ) - ) - dummy_tensors.append( - ttnn.as_tensor( - torch.zeros(1, 1, 32, 32 * 32), - device=device, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - ttnn.CoreRangeSet({num_to_corerange(32)}), - (32, 32), - ttnn.ShardOrientation.ROW_MAJOR, - False, - ), - ), - ) - ) - run_test_sdpa_decode_single_iter(device, b, nh, nkv, s, d, dtype) - - assert device.num_program_cache_entries() == 1 diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp index 9c026218b99..dbe87a5a944 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/reader_interleaved.cpp @@ -45,7 +45,6 @@ void kernel_main() { constexpr uint32_t cb_q_in = tt::CB::c_in0; constexpr uint32_t cb_k_in = tt::CB::c_in1; constexpr uint32_t cb_v_in = tt::CB::c_in2; - constexpr uint32_t cb_mask_in = tt::CB::c_in3; constexpr uint32_t onetile = 1; @@ -55,8 +54,6 @@ void kernel_main() { constexpr DataFormat k_data_format 
= get_dataformat(cb_k_in); constexpr uint32_t v_tile_bytes = get_tile_size(cb_v_in); constexpr DataFormat v_data_format = get_dataformat(cb_v_in); - constexpr uint32_t mask_tile_bytes = get_tile_size(cb_mask_in); - constexpr DataFormat mask_data_format = get_dataformat(cb_mask_in); constexpr uint32_t barrier_threshold = get_barrier_read_threshold(); @@ -80,12 +77,6 @@ void kernel_main() { .data_format = v_data_format }; - const InterleavedAddrGenFast mask_reader = { - .bank_base_address = mask_addr, - .page_size = mask_tile_bytes, - .data_format = mask_data_format - }; - uint32_t q_tile_id = 0; uint32_t k_tile_id = 0; uint32_t v_tile_id = 0; @@ -164,39 +155,6 @@ void kernel_main() { noc_async_read_barrier(); cb_push_back(cb_k_in, k_chunk_tiles); - - // Finding the diagonal is harder now that q_chunk_size and k_chunk_size can differ - // Q-range = [q_low, q_high) - // K-range = [k_low, k_high) - // does_overlap = not (q_low >= k_high or k_low >= q_high) - // Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check - // Read mask chunk - if (!(q_low_idx >= k_high_idx)) { - cb_reserve_back(cb_mask_in, mask_chunk_tiles); - uint32_t mask_write_ptr = get_write_ptr(cb_mask_in); - barrier_count = 0; - mask_tile_id = mask_batch_offset + q_chunk * Sq_chunk_t * St /*row_offset*/ + k_chunk * Sk_chunk_t /*col_offset*/; - for (uint32_t row = 0; row < Sq_chunk_t; ++row) { - for (uint32_t col = 0; col < Sk_chunk_t; ++col) { - noc_async_read_tile(mask_tile_id, mask_reader, mask_write_ptr); - mask_tile_id += 1; - mask_write_ptr += mask_tile_bytes; - - if (++barrier_count == barrier_threshold) { - noc_async_read_barrier(); - barrier_count = 0; - } - } - // Strid along columns to get to next row - mask_tile_id -= Sk_chunk_t; - mask_tile_id += St; - } - noc_async_read_barrier(); - cb_push_back(cb_mask_in, mask_chunk_tiles); - } - - - v_tile_id = v_batch_offset + k_chunk * Sk_chunk_t * DHt; // Read V chunk cb_reserve_back(cb_v_in, k_chunk_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp index ca4e57a0855..9687c07eba3 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/dataflow/writer_interleaved.cpp @@ -11,6 +11,126 @@ constexpr uint32_t get_barrier_read_threshold() { return ((512 / num_readers) * (1024 + 128)) / tile_bytes; } +template +void copy_tile(uint64_t noc_read_addr_base, uint32_t q_write_ptr_base, uint32_t src_tile_id, uint32_t dst_tile_id) { + noc_async_read(noc_read_addr_base + src_tile_id*tile_bytes, q_write_ptr_base + dst_tile_id*tile_bytes, tile_bytes); +} + +template +void fill_tile(uint32_t cb_id, uint32_t tile_id, uint32_t val) { + if (val == 0){ + constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE; + uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE); + uint32_t write_addr = get_write_ptr(cb_id) + tile_id*tile_bytes; + volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(write_addr); + + // Fill tile with zeros + for (uint32_t i = 0; i < num_zeros_reads; ++i) { + noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE); + write_addr += MEM_ZEROS_SIZE; + } + noc_async_read_barrier(); + } + else { + // Fill 2 uint16 datums in each writes to optimize for performance + volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(get_write_ptr(cb_id) + tile_id*tile_bytes); + constexpr int 
num_uint32_datums_tile = (32 * 32) / 2; + for (int k = 0; k < num_uint32_datums_tile; k++) { + ptr[k] = val; + } + } +} + +template +void fill_diagonal_tile(uint32_t cb_id, uint32_t tile_id, uint32_t partial_val) { + /* + We want to fill cur_pos_in_tile + 1 to the end + */ + + fill_tile(cb_id, tile_id, 0); + + // DPRINT << "Fill partial tile" << ENDL(); + const uint16_t datum_val = partial_val>>16; + volatile tt_l1_ptr uint16_t* uint16_ptr = reinterpret_cast(get_write_ptr(cb_id) + tile_id*tile_bytes); + volatile tt_l1_ptr uint32_t* uint32_ptr = reinterpret_cast(get_write_ptr(cb_id) + tile_id*tile_bytes); + + constexpr uint32_t uint16_datums_per_face_row = 16; + constexpr uint32_t uint32_datums_per_face_row = 8; + constexpr uint32_t uint32_datums_per_face = (16 * 16) / 2; + // Fill diagonal faces with diagonal -inf + for (uint32_t k = 0; k < 4; k+=3) { + uint32_t uint16_face_idx = k << 8; + uint32_t uint32_face_idx = k << 7; + for (uint32_t r = 0; r < uint16_datums_per_face_row; ++r) { + const uint32_t col_start = r+1; + const uint32_t col_start_uint32 = (col_start + 1) >> 1; + if ((col_start) % 2 == 1) { + uint16_ptr[uint16_face_idx + r * uint16_datums_per_face_row + col_start] = datum_val; + } + for (uint32_t c = col_start_uint32; c < uint32_datums_per_face_row; ++c) { + uint32_ptr[uint32_face_idx + r * uint32_datums_per_face_row + c] = partial_val; + } + } + } + + // Fill face 1 with full -inf + uint32_t uint32_face_idx = 1 << 7; + for (uint32_t j = 0; j < uint32_datums_per_face; j++) { + uint32_ptr[uint32_datums_per_face + j] = partial_val; + } +} + +template +void generate_mask(uint32_t Sq_chunk_t, uint32_t Sk_chunk_t, uint32_t q_chunk, uint32_t k_chunk) { + uint32_t mask_size_tiles = Sq_chunk_t * Sk_chunk_t; + constexpr uint32_t NEG_INF = 0xFF80FF80; // TODO: Make sure this is -inf + cb_reserve_back(cb_mask_in, mask_size_tiles); + + uint32_t write_ptr_base = get_write_ptr(cb_mask_in); + uint64_t noc_write_addr_base = get_noc_addr(write_ptr_base); + constexpr uint32_t tile_bytes = get_tile_size(cb_mask_in); + + int zero_tile_idx = -1; + int inf_tile_idx = -1; + int diag_tile_idx = -1; + + // TODO: cache indices of prepared tiles + for (uint32_t q_tile = 0; q_tile < Sq_chunk_t; ++q_tile) { + for (uint32_t k_tile = 0; k_tile < Sk_chunk_t; ++k_tile) { + uint32_t in_mask_tile_id = q_tile * Sk_chunk_t + k_tile; + uint32_t global_q_tile = Sq_chunk_t * q_chunk + q_tile; + uint32_t global_k_tile = Sk_chunk_t * k_chunk + k_tile; + + if (global_k_tile < global_q_tile) { + if (zero_tile_idx == -1) { + fill_tile(cb_mask_in, in_mask_tile_id, 0); + zero_tile_idx = in_mask_tile_id; + } else { + copy_tile(noc_write_addr_base, write_ptr_base, zero_tile_idx, in_mask_tile_id); + } + } + else if (global_k_tile == global_q_tile) { + if (diag_tile_idx == -1) { + fill_diagonal_tile(cb_mask_in, in_mask_tile_id, NEG_INF); + diag_tile_idx = in_mask_tile_id; + } else { + copy_tile(noc_write_addr_base, write_ptr_base, diag_tile_idx, in_mask_tile_id); + } + } + else { + if (inf_tile_idx == -1) { + fill_tile(cb_mask_in, in_mask_tile_id, NEG_INF); + inf_tile_idx = in_mask_tile_id; + } else { + copy_tile(noc_write_addr_base, write_ptr_base, inf_tile_idx, in_mask_tile_id); + } + } + } + } + noc_async_read_barrier(); + cb_push_back(cb_mask_in, mask_size_tiles); +} + void kernel_main() { constexpr uint32_t B = get_compile_time_arg_val(0); constexpr uint32_t NQH = get_compile_time_arg_val(1); @@ -36,10 +156,12 @@ void kernel_main() { const uint32_t q_chunks_per_core = local_q_end - local_q_start; + constexpr uint32_t 
mask_chunk_tiles = Sq_chunk_t * Sk_chunk_t; constexpr uint32_t out_chunk_tiles = Sq_chunk_t * DHt; constexpr bool is_dram = true; constexpr uint32_t cb_out = tt::CB::c_out0; + constexpr uint32_t cb_mask_in = tt::CB::c_in3; constexpr uint32_t tile_bytes = get_tile_size(cb_out); constexpr DataFormat data_format = get_dataformat(cb_out); @@ -82,6 +204,23 @@ void kernel_main() { uint32_t q_chunk_offset = q_chunk * Sq_chunk_t * DHt; out_tile_id = q_batch_offset + q_head_offset + q_chunk_offset; + const uint32_t q_low_idx = q_chunk * Sq_chunk_t; // This is the sequence index of the first tile of this chunk + const uint32_t q_high_idx = q_low_idx + Sq_chunk_t; + + for (uint32_t k_chunk = 0; (k_chunk * Sk_chunk_t) < q_high_idx; ++k_chunk) { + const uint32_t k_low_idx = k_chunk * Sk_chunk_t; + const uint32_t k_high_idx = k_low_idx + Sk_chunk_t; + // Finding the diagonal is harder now that q_chunk_size and k_chunk_size can differ + // Q-range = [q_low, q_high) + // K-range = [k_low, k_high) + // does_overlap = not (q_low >= k_high or k_low >= q_high) + // Due to loop bounds, we should never have k_low >= q_high. Can simplify this conditional check + // Read mask chunk + if (!(q_low_idx >= k_high_idx)) { + generate_mask(Sq_chunk_t, Sk_chunk_t, q_chunk, k_chunk); + } + } + // Wait for compute to deliver output chunk cb_wait_front(cb_out, out_chunk_tiles); barrier_count = 0; diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp index 5eea68f336e..ff2b36cc2f9 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_op.cpp @@ -12,61 +12,62 @@ namespace ttnn::operations::transformer { void ScaledDotProductAttention::validate( const std::vector& input_tensors, const std::vector>& optional_input_tensors) const { - TT_FATAL(input_tensors.size() == 3 and optional_input_tensors.size() == 1, "Must have 3 input tensors and mask"); + TT_FATAL(input_tensors.size() == 3 and optional_input_tensors.size() == 1, "Must have 3 input tensors and optional mask"); for (auto& input_tensor : input_tensors) { - TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!"); - TT_FATAL(input_tensor.buffer() != nullptr, "Operands to SDPA need to be allocated in buffers on device!"); + TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device"); + TT_FATAL(input_tensor.buffer() != nullptr, "Operands to SDPA need to be allocated in buffers on device"); TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to SDPA must be tilized"); TT_FATAL(input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::BFLOAT8_B); } - auto mask = optional_input_tensors.at(0).value(); - TT_FATAL(mask.storage_type() == StorageType::DEVICE, "Operands to SDPA need to be on device!"); - TT_FATAL(input_tensors.at(0).device() == mask.device()); - TT_FATAL(mask.get_layout() == Layout::TILE); - TT_FATAL(mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B); - - TT_FATAL(mask.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM); + const auto& mask_option = optional_input_tensors.at(0); + if (mask_option.has_value()){ + TT_FATAL(!this->is_causal, "Causal SDPA does not take mask as input"); + auto mask = optional_input_tensors.at(0).value(); + TT_FATAL(mask.storage_type() == StorageType::DEVICE, "When mask is provided to SDPA, the tensor must be on 
device"); + TT_FATAL(input_tensors.at(0).device() == mask.device(), "When mask is provided to SDPA, it must be on the same device as the input tensors"); + TT_FATAL(mask.get_layout() == Layout::TILE, "When mask is provided to SDPA, it must be tilized"); + TT_FATAL(mask.get_dtype() == DataType::BFLOAT16 || mask.get_dtype() == DataType::BFLOAT8_B, "When mask is provided to SDPA, it must be in BF16 or BFP8 dataformat"); + TT_FATAL(input_tensors.at(0).get_dtype() == mask_option.value().get_dtype(), "When mask is provided to SDPA, it must have the same dataformat as the input tensors"); + + TT_FATAL(mask.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, "When mask is provided to SDPA, it must be in DRAM"); + } const auto q_shape = input_tensors.at(0).get_legacy_shape(); const auto k_shape = input_tensors.at(1).get_legacy_shape(); const auto v_shape = input_tensors.at(2).get_legacy_shape(); - const auto mask_shape = mask.get_legacy_shape(); // assert all dataformats are the same TT_FATAL( input_tensors.at(0).get_dtype() == input_tensors.at(1).get_dtype() && - input_tensors.at(0).get_dtype() == input_tensors.at(2).get_dtype() && - input_tensors.at(0).get_dtype() == mask.get_dtype()); + input_tensors.at(0).get_dtype() == input_tensors.at(2).get_dtype(), "All inputs to SDPA must have the same dataformat"); if (this->is_causal) { // All inputs must be in DRAM for (auto& input_tensor : input_tensors) { - TT_FATAL(input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM); + TT_FATAL(input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM, "All inputs to causal SDPA must be in DRAM"); } // Check sequence lengths - TT_FATAL(q_shape[-2] == k_shape[-2] && q_shape[-2] == v_shape[-2]); - TT_FATAL(q_shape[-2] == mask_shape[-2] && q_shape[-2] == mask_shape[-1]); + TT_FATAL(q_shape[-2] == k_shape[-2] && q_shape[-2] == v_shape[-2], "Q, K, V sequence dim must match. Got Q: {}, K: {}, V: {}", q_shape[-2], k_shape[-2], v_shape[-2]); // Check batch size - TT_FATAL(q_shape[-4] == k_shape[-4] && q_shape[-4] == v_shape[-4]); - TT_FATAL(q_shape[-4] == mask_shape[-4]); + TT_FATAL(q_shape[-4] == k_shape[-4] && q_shape[-4] == v_shape[-4], "Q, K, V batch dim must match. Got Q: {}, K: {}, V: {}", q_shape[-4], k_shape[-4], v_shape[-4]); // Check hidden size - TT_FATAL(q_shape[-1] == k_shape[-1] && q_shape[-1] == v_shape[-1]); + TT_FATAL(q_shape[-1] == k_shape[-1] && q_shape[-1] == v_shape[-1], "Q, K, V hidden dim must match. Got Q: {}, K: {}, V: {}", q_shape[-1], k_shape[-1], v_shape[-1]); // Check kv heads - TT_FATAL(k_shape[-3] == v_shape[-3]); + TT_FATAL(k_shape[-3] == v_shape[-3], "K, V heads dim must match. Got K: {}, V: {}", k_shape[-3], v_shape[-3]); // Check qkv heads - TT_FATAL(q_shape[-3] >= k_shape[-3]); + TT_FATAL(q_shape[-3] >= k_shape[-3], "Q heads must be >= K heads. Got Q: {}, K: {}", q_shape[-3], k_shape[-3]); - TT_FATAL(mask_shape[-3] == 1); - - TT_FATAL(this->output_mem_config.buffer_type == tt::tt_metal::BufferType::DRAM); + TT_FATAL(this->output_mem_config.buffer_type == tt::tt_metal::BufferType::DRAM, "Output must be in DRAM"); } else { + const auto mask_shape = mask_option.value().get_legacy_shape(); + // Input 0 must be sharded by height. All other inputs must be in DRAM. 
const auto Q_memcfg = input_tensors.at(0).memory_config(); TT_FATAL(input_tensors.at(0).is_sharded() == true); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp index 0780a9a70bc..0190b90a748 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/sdpa_program_factory.cpp @@ -107,7 +107,6 @@ operation::ProgramWithCallbacks sdpa_multi_core( auto k_buffer = input_tensor_k.buffer(); auto v_buffer = input_tensor_v.buffer(); auto mask_buffer = attn_mask.has_value() ? attn_mask.value().buffer() : nullptr; - TT_FATAL(mask_buffer != nullptr); auto out0_buffer = output_tensor.buffer(); @@ -351,7 +350,7 @@ operation::ProgramWithCallbacks sdpa_multi_core( uint32_t q_tile_size = tt::tt_metal::detail::TileSize(q_df); uint32_t k_tile_size = tt::tt_metal::detail::TileSize(k_df); uint32_t v_tile_size = tt::tt_metal::detail::TileSize(v_df); - uint32_t mask_tile_size = attn_mask.has_value() ? tt::tt_metal::detail::TileSize(mask_df) : 0; + uint32_t mask_tile_size = tt::tt_metal::detail::TileSize(mask_df); uint32_t out_tile_size = tt::tt_metal::detail::TileSize(out_df); uint32_t scalar_tile_size = tt::tt_metal::detail::TileSize(scalar_df); uint32_t im_tile_size = tt::tt_metal::detail::TileSize(im_df); @@ -448,7 +447,7 @@ operation::ProgramWithCallbacks sdpa_multi_core( uint32_t q_addr = q_buffer->address(); uint32_t k_addr = k_buffer->address(); uint32_t v_addr = v_buffer->address(); - uint32_t mask_addr = mask_buffer->address(); + uint32_t mask_addr = attn_mask.has_value() ? mask_buffer->address() : 0; uint32_t out_addr = out0_buffer->address(); // Set reader rt args @@ -542,13 +541,12 @@ operation::ProgramWithCallbacks sdpa_multi_core( auto v_buffer = input_tensors.at(2).buffer(); auto mask_buffer = optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer() : nullptr; - TT_FATAL(mask_buffer != nullptr); auto out0_buffer = output_tensors.at(0).buffer(); uint32_t q_addr = q_buffer->address(); uint32_t k_addr = k_buffer->address(); uint32_t v_addr = v_buffer->address(); - uint32_t mask_addr = mask_buffer->address(); + uint32_t mask_addr = mask_buffer != nullptr ? mask_buffer->address() : 0; uint32_t out_addr = out0_buffer->address(); auto& reader_args_by_core = GetRuntimeArgs(program, reader_kernels_id); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp index 773df170383..e4e9d311945 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/sdpa_pybind.cpp @@ -63,8 +63,8 @@ void py_bind_sdpa(py::module &module) { py::arg("input_tensor_q").noconvert(), py::arg("input_tensor_k").noconvert(), py::arg("input_tensor_v").noconvert(), - py::arg("causal_mask").noconvert(), py::kw_only(), + py::arg("causal_mask").noconvert() = std::nullopt, py::arg("is_causal").noconvert() = true, py::arg("scale").noconvert() = std::nullopt, py::arg("memory_config").noconvert() = std::nullopt,
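
Since this patch moves causal-mask construction off the host and into the writer kernel (generate_mask() emits all-zero tiles below the diagonal, fully masked tiles above it, and a partially filled tile on the diagonal via fill_diagonal_tile()), callers now drop the mask argument entirely and pass is_causal=True, as the updated unit test and model code above do. The host-side PyTorch sketch below is for illustration only: it mirrors the per-tile classification at 32x32 tile granularity and checks that the result is equivalent to torch's built-in causal masking. The names build_causal_mask_by_tiles, TILE, and NEG_INF are hypothetical helpers for this sketch, not ttnn or kernel APIs, and the masking value here is float32 min (as in the original host test) rather than the packed bfloat16 0xFF80 constant the kernel writes.

import torch
import torch.nn.functional as F

TILE = 32  # sketch assumes 32x32 tiles, matching the tilized layouts used by the op
NEG_INF = torch.finfo(torch.float32).min  # host-side stand-in for the kernel's bf16 -inf


def build_causal_mask_by_tiles(s: int) -> torch.Tensor:
    """Classify each (q_tile, k_tile) pair the same way generate_mask() does:
    k_tile < q_tile -> all zeros, k_tile > q_tile -> all -inf,
    k_tile == q_tile -> -inf strictly above the main diagonal of that tile."""
    assert s % TILE == 0, "sketch assumes a tile-aligned sequence length"
    st = s // TILE
    mask = torch.empty(s, s)
    # Partial tile used on the diagonal, analogous to fill_diagonal_tile()
    diag_tile = torch.triu(torch.full((TILE, TILE), NEG_INF), diagonal=1)
    for q_tile in range(st):
        for k_tile in range(st):
            rows = slice(q_tile * TILE, (q_tile + 1) * TILE)
            cols = slice(k_tile * TILE, (k_tile + 1) * TILE)
            if k_tile < q_tile:        # fully visible keys
                mask[rows, cols] = 0.0
            elif k_tile == q_tile:     # diagonal tile: partial causal fill
                mask[rows, cols] = diag_tile
            else:                      # strictly future keys: fully masked
                mask[rows, cols] = NEG_INF
    return mask


if __name__ == "__main__":
    b, nh, s, d = 1, 8, 128, 64
    Q, K, V = (torch.randn(b, nh, s, d) for _ in range(3))

    mask = build_causal_mask_by_tiles(s)  # broadcasts over (b, nh, s, s)
    with_mask = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask, is_causal=False)
    causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
    assert torch.allclose(with_mask, causal, atol=1e-6)

The device kernel prepares at most one zero tile, one fully masked tile, and one diagonal tile per chunk and replicates them with copy_tile(), so the per-element loop above only describes the resulting values, not the on-device data movement. Non-causal SDPA still accepts an optional user-supplied mask, which ScaledDotProductAttention::validate now checks only when it is actually provided.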