diff --git a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp
index 768639149a1..314328ac286 100644
--- a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp
+++ b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp
@@ -556,6 +556,8 @@ INSTANTIATE_TEST_SUITE_P(
             .in0_block_w = 2,
             .out_subblock_h = 1,
             .out_subblock_w = 1,
+            .out_block_h = 64,
+            .out_block_w = 2,
             .per_core_M = 64,
             .per_core_N = 2,
             .fuse_batch = true,
diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py
index d66fb34f464..c411ab46631 100644
--- a/tests/ttnn/unit_tests/operations/test_matmul.py
+++ b/tests/ttnn/unit_tests/operations/test_matmul.py
@@ -958,6 +958,232 @@ def test_matmul_1d_tiny_tile(
     assert device.num_program_cache_entries() == 1
 
 
+def run_matmul_1d_multiple_output_blocks_per_core(
+    device,
+    m,
+    k,
+    n,
+    has_bias,
+    grid_size,
+    in_sharded,
+    out_sharded,
+    num_out_block_h,
+    num_out_block_w,
+    mcast_in0,
+    uneven_width,
+):
+    if in_sharded or out_sharded:
+        fuse_batch = True
+    else:
+        fuse_batch = False
+
+    if out_sharded and num_out_block_w > 1:
+        pytest.skip("out_sharded does not support multiple blocks on the w dim")
+
+    if not mcast_in0:
+        tmp = m
+        m = n
+        n = tmp
+
+    in0_shape = [1, 1, m, k]
+    in1_shape = [1, 1, k, n]
+    bias_shape = [1, 1, n]
+
+    num_cores = grid_size[0] * grid_size[1]
+
+    if mcast_in0:
+        in0_block_w = k // num_cores // 32
+        per_core_M = m // 32
+        per_core_N = n // num_cores // 32 + uneven_width
+    else:
+        in0_block_w = k // 32 // 2  # test extracting shards
+        per_core_M = m // 32 // num_cores
+        per_core_N = n // 32
+    out_block_h = per_core_M // num_out_block_h
+    out_block_w = per_core_N // num_out_block_w
+    out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w)
+
+    logger.info(f"m: {m}")
+    logger.info(f"k: {k}")
+    logger.info(f"n: {n}")
+    logger.info(f"in0_block_w: {in0_block_w}")
+    logger.info(f"per_core_M: {per_core_M}")
+    logger.info(f"per_core_N: {per_core_N}")
+    logger.info(f"out_block_h: {out_block_h}")
+    logger.info(f"out_block_w: {out_block_w}")
+    logger.info(f"out_subblock_h: {out_subblock_h}")
+    logger.info(f"out_subblock_w: {out_subblock_w}")
+
+    in0 = torch.randn(in0_shape).bfloat16().float()
+    in1 = torch.randn(in1_shape).bfloat16().float()
+
+    if in_sharded:
+        if mcast_in0:
+            in0_memory_config = ttnn.create_sharded_memory_config(
+                (1, 1, m, k),
+                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+            )
+        else:
+            in0_memory_config = ttnn.create_sharded_memory_config(
+                (1, 1, m, k),
+                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
+                strategy=ttnn.ShardStrategy.HEIGHT,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+            )
+    else:
+        in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
+    in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
+    in0_t = ttnn.from_torch(
+        in0,
+        dtype=ttnn.bfloat16,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=in0_memory_config,
+    )
+    in1_t = ttnn.from_torch(
+        in1,
+        dtype=ttnn.bfloat8_b,
+        layout=ttnn.TILE_LAYOUT,
+        device=device,
+        memory_config=in1_memory_config,
+    )
+
+    if has_bias:
+        bias = torch.randn(bias_shape).bfloat16().float()
+        bias_padded = bias.unsqueeze(2)
+        bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, 32 - bias_padded.size(2)), "constant", 0)
+        bias_t = ttnn.from_torch(
+            bias_padded,
+            dtype=ttnn.bfloat16,
+            layout=ttnn.TILE_LAYOUT,
+            device=device,
+            memory_config=ttnn.DRAM_MEMORY_CONFIG,
+        )
+
+    program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
+        compute_with_storage_grid_size=grid_size,
+        in0_block_w=in0_block_w,
+        out_subblock_h=out_subblock_h,
+        out_subblock_w=out_subblock_w,
+        out_block_h=out_block_h,
+        out_block_w=out_block_w,
+        per_core_M=per_core_M,
+        per_core_N=per_core_N,
+        fuse_batch=fuse_batch,
+        fused_activation=None,
+        mcast_in0=mcast_in0,
+    )
+
+    if is_grayskull():
+        compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+        )
+    else:
+        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=True,
+        )
+    if out_sharded:
+        if mcast_in0:
+            out_mem_config = ttnn.MemoryConfig(
+                memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+                buffer_type=ttnn.BufferType.L1,
+            )
+        else:
+            out_mem_config = ttnn.MemoryConfig(
+                memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+                buffer_type=ttnn.BufferType.L1,
+            )
+    else:
+        out_mem_config = ttnn.DRAM_MEMORY_CONFIG
+
+    if has_bias:
+        output_t = ttnn.linear(
+            in0_t,
+            in1_t,
+            bias=bias_t,
+            program_config=program_config,
+            memory_config=out_mem_config,
+            dtype=ttnn.bfloat16,
+            compute_kernel_config=compute_kernel_config,
+        )
+    else:
+        output_t = ttnn.matmul(
+            in0_t,
+            in1_t,
+            program_config=program_config,
+            memory_config=out_mem_config,
+            dtype=ttnn.bfloat16,
+            compute_kernel_config=compute_kernel_config,
+        )
+    output_tensor = ttnn.to_torch(output_t)
+    pt_out = in0 @ in1
+    if has_bias:
+        pt_out += bias
+
+    assert_with_pcc(pt_out, output_tensor, 0.999)
+
+
+@run_for_wormhole_b0()
+@pytest.mark.parametrize("m", [256])
+@pytest.mark.parametrize("k", [1024])
+@pytest.mark.parametrize("n", [2048])
+@pytest.mark.parametrize("has_bias", [False])
+@pytest.mark.parametrize("grid_size", [(8, 2)])
+@pytest.mark.parametrize("in_sharded", [True, False])
+@pytest.mark.parametrize("out_sharded", [True, False])
+@pytest.mark.parametrize("num_out_block_h", [1, 2])
+@pytest.mark.parametrize("num_out_block_w", [1, 2])
+@pytest.mark.parametrize("mcast_in0", [True, False])
+@pytest.mark.parametrize("uneven_width", [0, 2])
+def test_matmul_1d_multiple_output_blocks_per_core(
+    device,
+    m,
+    k,
+    n,
+    has_bias,
+    grid_size,
+    in_sharded,
+    out_sharded,
+    num_out_block_h,
+    num_out_block_w,
+    mcast_in0,
+    uneven_width,
+    use_program_cache,
+):
+    for _ in range(2):
+        run_matmul_1d_multiple_output_blocks_per_core(
+            device,
+            m,
+            k,
+            n,
+            has_bias,
+            grid_size,
+            in_sharded,
+            out_sharded,
+            num_out_block_h,
+            num_out_block_w,
+            mcast_in0,
+            uneven_width,
+        )
+        # dummy tensor to change tensor alloc
+        dummy_shape = [1, 1, 32, 32]
+        py_dummy_tensor = torch.randn(dummy_shape)
+        tt_dummy_tensor = ttnn.from_torch(
+            py_dummy_tensor,
+            dtype=ttnn.DataType.BFLOAT16,
+            layout=ttnn.TILE_LAYOUT,
+            device=device,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+    assert device.num_program_cache_entries() == 1
+
+
 # fmt: off
 @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
 @pytest.mark.parametrize("m_size,k_size,n_size", [
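
The blocking arithmetic in the helper above splits each core's output into num_out_block_h x num_out_block_w blocks. A minimal worked example for one of the parametrized points (derived values, not part of the patch):

    # mcast_in0=True, m=256, k=1024, n=2048, grid_size=(8, 2), uneven_width=0
    num_cores = 8 * 2                      # 16
    in0_block_w = 1024 // num_cores // 32  # 2 tiles of K per inner-dim block
    per_core_M = 256 // 32                 # 8 tile rows of output per core
    per_core_N = 2048 // num_cores // 32   # 4 tile columns of output per core
    out_block_h = per_core_M // 2          # num_out_block_h=2 -> 4 tiles
    out_block_w = per_core_N // 2          # num_out_block_w=2 -> 2 tiles
    assert per_core_M % out_block_h == 0 and per_core_N % out_block_w == 0
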
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
index bf215230584..9060a439669 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -754,6 +754,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_conv_op_config(
         .in0_block_w = conv_blocking_config.act_block_w_ntiles,
         .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
         .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
+        .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
+        .out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH),
        .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
        .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH),
        .fuse_batch = true,
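
For the conv-as-matmul path, out_block_h/out_block_w are set to the same expressions as per_core_M/per_core_N, so convolution keeps a single output block per core; div_up only converts the per-core output extents from elements to tiles. A sketch of that conversion (illustrative values, not repo code):

    def div_up(a, b):  # ceiling division, as used in the conv2d_utils change
        return (a + b - 1) // b

    TILE_HEIGHT = 32
    per_core_out_matrix_height = 224                               # illustrative
    out_block_h = div_up(per_core_out_matrix_height, TILE_HEIGHT)  # 7 tiles, same value as per_core_M
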
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp
index c84a4d4825a..0532ae005cc 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp
@@ -71,14 +71,17 @@ void kernel_main() {
     // In case we need to send multiple blocks per shard, in0 sharded cb is cb2 and we extract the sub-blocks to cb0
     constexpr uint32_t shard_read_stride = shard_width_in_tiles * in0_single_tile_size_bytes;
     constexpr uint32_t shard_read_width = in0_single_tile_size_bytes * in0_block_w;
+    constexpr uint32_t shard_num_tiles = shard_width_in_tiles * shard_height_in_tiles;
+    constexpr uint32_t in0_tensor_next_h_dim_block_stride_bytes =
+        in0_tensor_next_h_dim_block_stride * in0_single_tile_size_bytes;
 
-    uint64_t noc_shard_read_start_addr = 0;
+    uint32_t noc_shard_read_start_addr = 0;
     if constexpr (extract_shard_sub_blocks) {
         constexpr uint32_t cb_id_in2 = 2;  // in0 sharded cb if extract_shard_sub_blocks
-        noc_shard_read_start_addr = get_noc_addr(get_read_ptr(cb_id_in2));
+        noc_shard_read_start_addr = get_read_ptr(cb_id_in2);
     } else {
-        cb_reserve_back(cb_id_in0, in0_block_num_tiles);
-        cb_push_back(cb_id_in0, in0_block_num_tiles);
+        cb_reserve_back(cb_id_in0, shard_num_tiles);
+        cb_push_back(cb_id_in0, shard_num_tiles);
     }
 #else
     constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0);
@@ -113,9 +116,15 @@ void kernel_main() {
 #endif
 
     for (uint32_t b = 0; b < batch; ++b) {
+#ifdef IN0_SHARDED
+        uint32_t in0_tensor_current_h_dim_block_start_addr = noc_shard_read_start_addr;
+#endif
         uint32_t in0_tensor_current_h_dim_block_tile_id = in0_tensor_start_tile_id;
         for (uint32_t bh = 0; bh < num_blocks_h_dim; ++bh) {
             for (uint32_t bw = 0; bw < num_blocks_w_dim; ++bw) {
+#ifdef IN0_SHARDED
+                uint32_t in0_tensor_current_inner_dim_block_start_addr = in0_tensor_current_h_dim_block_start_addr;
+#endif
                 uint32_t in0_tensor_current_inner_dim_block_start_tile_id = in0_tensor_current_h_dim_block_tile_id;
                 for (uint32_t block = 0; block < num_blocks_inner_dim; ++block) {
                     if constexpr (fuse_op) {
@@ -159,16 +168,16 @@ void kernel_main() {
                         in0_start_address = l1_write_addr_in0;  // copy start address of block, to be used for mcasting
 #endif
 
-                        uint64_t noc_shard_read_addr = noc_shard_read_start_addr;
-                        noc_shard_read_start_addr += shard_read_width;
+                        uint64_t noc_shard_read_addr = get_noc_addr(in0_tensor_current_inner_dim_block_start_addr);
 
-                        for (uint32_t i = 0; i < shard_height_in_tiles; i++) {
+                        for (uint32_t i = 0; i < in0_block_h; i++) {
                             noc_async_read(noc_shard_read_addr, l1_write_addr_in0, shard_read_width);
                             l1_write_addr_in0 += shard_read_width;
                             noc_shard_read_addr += shard_read_stride;
                         }
+                        in0_tensor_current_inner_dim_block_start_addr += shard_read_width;
                         noc_async_read_barrier();
                     }
 #endif
@@ -216,6 +225,9 @@ void kernel_main() {
 #endif
                 }
             }
+#ifdef IN0_SHARDED
+            in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride_bytes;
+#endif
             in0_tensor_current_h_dim_block_tile_id += in0_tensor_next_h_dim_block_stride;
         }
         in0_tensor_start_tile_id += MtKt;
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index a57ef49b74c..c4745856b80 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -306,6 +306,8 @@ MatmulProgramConfig create_matmul_1d_systolic_array_program_config(
         .in0_block_w = k_tiles_per_core,
         .out_subblock_h = out_subblock_h,
         .out_subblock_w = out_subblock_w,
+        .out_block_h = batch_and_m_tiles_per_core,
+        .out_block_w = n_tiles_per_core,
         .per_core_M = batch_and_m_tiles_per_core,
         .per_core_N = n_tiles_per_core,
         .fuse_batch = true,
@@ -357,6 +359,8 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config(
         .in0_block_w = in0_block_w,
         .out_subblock_h = out_subblock_h,
         .out_subblock_w = out_subblock_w,
+        .out_block_h = per_core_M,
+        .out_block_w = per_core_N,
         .per_core_M = per_core_M,
         .per_core_N = per_core_N,
         .fuse_batch = fuse_batch,
@@ -701,6 +705,8 @@ MatmulProgramConfig get_matmul_program_config(
             .in0_block_w = in0_block_w,
             .out_subblock_h = out_subblock_h,
             .out_subblock_w = out_subblock_w,
+            .out_block_h = per_core_M,
+            .out_block_w = per_core_N,
             .per_core_M = per_core_M,
             .per_core_N = per_core_N,
             .fuse_batch = true,
@@ -1182,6 +1188,26 @@ void Matmul::validate(
                 // TODO: For 1D and 2D mcasts, we don't check if tensor is single core or single row/col
                 // We can uplift these variants to skip mcasting to support single core (1D) or single row/col (2D)
                 if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCast1DProgramConfig>) {
+                    TT_FATAL(
+                        program_config.per_core_M % program_config.out_block_h == 0,
+                        "Error: incompatible values {} and {}",
+                        program_config.per_core_M,
+                        program_config.out_block_h);
+                    TT_FATAL(
+                        program_config.per_core_N % program_config.out_block_w == 0,
+                        "Error: incompatible values {} and {}",
+                        program_config.per_core_N,
+                        program_config.out_block_w);
+                    TT_FATAL(
+                        program_config.out_block_h % program_config.out_subblock_h == 0,
+                        "Error: incompatible values {} and {}",
+                        program_config.out_block_h,
+                        program_config.out_subblock_h);
+                    TT_FATAL(
+                        program_config.out_block_w % program_config.out_subblock_w == 0,
+                        "Error: incompatible values {} and {}",
+                        program_config.out_block_w,
+                        program_config.out_subblock_w);
                     TT_FATAL(
                         !(program_config.mcast_in0 && program_config.gather_in0),
                         "Matmul1D does not support mcast_in0 and gather_in0 at the same time.");
@@ -1806,6 +1832,8 @@ operation::ProgramWithCallbacks Matmul::create_program(
                 program_config.in0_block_w,
                 program_config.out_subblock_h,
                 program_config.out_subblock_w,
+                program_config.out_block_h,
+                program_config.out_block_w,
                 program_config.per_core_M,
                 program_config.per_core_N,
                 program_config.fuse_batch,
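
The new TT_FATAL checks pin down the blocking hierarchy for the 1D mcast config: out blocks must evenly tile the per-core output range, and subblocks must evenly tile each out block. The same invariant, restated as a standalone predicate (a sketch, not repo code):

    def valid_1d_blocking(per_core_M, per_core_N, out_block_h, out_block_w, out_subblock_h, out_subblock_w):
        return (
            per_core_M % out_block_h == 0          # out blocks tile the per-core M extent
            and per_core_N % out_block_w == 0      # out blocks tile the per-core N extent
            and out_block_h % out_subblock_h == 0  # subblocks tile each out block
            and out_block_w % out_subblock_w == 0
        )

    assert valid_1d_blocking(64, 2, 64, 2, 1, 1)  # the gtest config at the top of this diff
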
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
index 4eea7a50f19..a4b41cb6519 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
@@ -43,6 +43,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(
     uint32_t in0_block_w,
     uint32_t out_subblock_h,
     uint32_t out_subblock_w,
+    uint32_t out_block_h,
+    uint32_t out_block_w,
     uint32_t per_core_M,
     uint32_t per_core_N,
     bool fuse_batch,
@@ -130,6 +132,8 @@ struct MatmulMultiCoreReuseMultiCast1DProgramConfig {
     std::size_t in0_block_w;
     std::size_t out_subblock_h;
     std::size_t out_subblock_w;
+    std::size_t out_block_h;
+    std::size_t out_block_w;
     std::size_t per_core_M;
     std::size_t per_core_N;
     bool fuse_batch;
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp
index b5352ab0d45..4e299855332 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp
@@ -71,6 +71,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     uint32_t in0_block_w,
     uint32_t out_subblock_h,
     uint32_t out_subblock_w,
+    uint32_t out_block_h,
+    uint32_t out_block_w,
     uint32_t per_core_M,
     uint32_t per_core_N,
     std::optional<UnaryWithParam> fused_activation,
@@ -113,7 +115,16 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format);
     uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format);
 
-    uint32_t in0_block_tiles = per_core_M * in0_block_w;
+    bool do_not_inplace_interm0_out_CB = output_is_sharded && (per_core_M != out_block_h);
+
+    uint32_t in0_block_h = out_block_h;
+    uint32_t in1_block_w = out_block_w;
+    uint32_t in0_num_blocks_y = per_core_M / out_block_h;
+    uint32_t in1_num_blocks_x = per_core_N / out_block_w;
+    uint32_t out_num_blocks_x = in1_num_blocks_x;
+    uint32_t out_num_blocks_y = in0_num_blocks_y;
+
+    uint32_t in0_block_tiles = in0_block_h * in0_block_w;
     uint32_t in0_CB_tiles = in0_block_tiles;
     if (B * num_blocks > 1) {
         in0_CB_tiles = in0_CB_tiles * 2;  // double buffer
@@ -131,7 +142,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     uint32_t in2_CB_tiles = in2_block_tiles;
     uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size;
 
-    uint32_t in1_block_tiles = per_core_N * in0_block_w;
+    uint32_t in1_block_tiles = out_block_w * in0_block_w;
     uint32_t in1_CB_tiles = in1_block_tiles;
     if (B * num_blocks > 1) {
         in1_CB_tiles = in1_CB_tiles * 2;  // double buffer
@@ -143,12 +154,17 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
 
     uint32_t in1_CB_size = in1_CB_tiles * in1_single_tile_size;
 
-    uint32_t out_block_tiles = per_core_M * per_core_N;
+    uint32_t out_block_tiles = out_block_h * out_block_w;
+    uint32_t out_shard_tiles = per_core_M * per_core_N;
     uint32_t out_CB_tiles = out_block_tiles;  // No double buffer
+    if (output_is_sharded) {
+        out_CB_tiles = out_shard_tiles;
+    }
     uint32_t out_CB_size = out_CB_tiles * output_single_tile_size;
-    uint32_t interm0_CB_size = out_CB_tiles * interm0_single_tile_size;
+    uint32_t interm0_CB_tiles = out_block_tiles;  // No double buffer
+    uint32_t interm0_CB_size = interm0_CB_tiles * interm0_single_tile_size;
 
-    uint32_t in3_block_tiles = per_core_N;
+    uint32_t in3_block_tiles = out_block_w;
     uint32_t in3_CB_tiles = in3_block_tiles;  // No double buffer
     uint32_t in3_CB_size = in3_CB_tiles * bias_single_tile_size;
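
With multiple output blocks per core, most circular buffers shrink to one out_block rather than the full per-core shard; the output CB is the exception and must still hold the whole per-core output when it is sharded. That size mismatch is also why interm0 can no longer be in-placed into the output CB in that case (do_not_inplace_interm0_out_CB). The sizing rules, restated as a sketch (sizes in tiles, not repo code):

    def cb_sizing(per_core_M, per_core_N, out_block_h, out_block_w, output_is_sharded):
        out_block_tiles = out_block_h * out_block_w
        out_CB_tiles = per_core_M * per_core_N if output_is_sharded else out_block_tiles
        interm0_CB_tiles = out_block_tiles  # always one block; no double buffer
        inplace_interm0_into_out = not (output_is_sharded and per_core_M != out_block_h)
        return out_CB_tiles, interm0_CB_tiles, inplace_interm0_into_out
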
@@ -252,7 +268,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     }
     bool out_is_dram = out_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
 
-    uint32_t in0_num_subblocks = (per_core_M / out_subblock_h);
+    uint32_t in0_num_subblocks = (out_block_h / out_subblock_h);
     uint32_t in0_block_num_tiles = out_subblock_h * in0_block_w * in0_num_subblocks;
 
     std::vector<uint32_t> in0_sender_compile_time_args;
@@ -264,9 +280,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         (std::uint32_t)in0_block_num_tiles,  // in0_block_num_tiles
         (std::uint32_t)in0_block_num_tiles * in0_single_tile_size,  // in0_block_size_bytes
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // num_blocks_x
-        (std::uint32_t)1,  // num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // num_blocks_y
         // in0 mcast args
         (std::uint32_t)in0_mcast_sender_semaphore_id,
         (std::uint32_t)in0_mcast_receiver_semaphore_id,
@@ -278,7 +294,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         (std::uint32_t)(in0_shard_width_in_tiles),
         (std::uint32_t)(in0_shard_height_in_tiles),
         (std::uint32_t)(in0_block_w),
-        (std::uint32_t)per_core_M,  // in0_block_h
+        (std::uint32_t)in0_block_h,  // in0_block_h
 
         // batch args
         (std::uint32_t)B  // batch
@@ -289,21 +305,21 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         (std::uint32_t)in0_is_dram,
 
         // in0 tensor args
-        (std::uint32_t)1,  // in0_tensor_stride_w
-        (std::uint32_t)K,  // in0_tensor_stride_h
-        (std::uint32_t)in0_block_w,  // in0_tensor_next_block_stride
-        (std::uint32_t)K * per_core_M,  // in0_tensor_next_h_dim_block_stride
+        (std::uint32_t)1,  // in0_tensor_stride_w
+        (std::uint32_t)K,  // in0_tensor_stride_h
+        (std::uint32_t)in0_block_w,  // in0_tensor_next_block_stride
+        (std::uint32_t)K * in0_block_h,  // in0_tensor_next_h_dim_block_stride
         // in0 block args
-        (std::uint32_t)in0_block_w,  // in0_block_w
-        (std::uint32_t)per_core_M,  // in0_block_h
-        (std::uint32_t)in0_block_w * per_core_M,  // in0_block_num_tiles
-        (std::uint32_t) false,  // extract_shard_sub_blocks (not used for interleaved)
-        (std::uint32_t)0,  // shard_width_in_tiles (not used for interleaved)
-        (std::uint32_t)0,  // shard_height_in_tiles (not used for interleaved)
+        (std::uint32_t)in0_block_w,  // in0_block_w
+        (std::uint32_t)in0_block_h,  // in0_block_h
+        (std::uint32_t)in0_block_num_tiles,  // in0_block_num_tiles
+        (std::uint32_t)false,  // extract_shard_sub_blocks (not used for interleaved)
+        (std::uint32_t)0,  // shard_width_in_tiles (not used for interleaved)
+        (std::uint32_t)0,  // shard_height_in_tiles (not used for interleaved)
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // num_blocks_x
-        (std::uint32_t)1,  // num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // num_blocks_y
         // in0 mcast args
         (std::uint32_t)in0_mcast_sender_semaphore_id,
         (std::uint32_t)in0_mcast_receiver_semaphore_id,
@@ -326,15 +342,15 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         (std::uint32_t)1,  // in1_tensor_stride_w
         (std::uint32_t)N,  // in1_tensor_stride_h
         (std::uint32_t)in0_block_w * N,  // in1_tensor_next_block_stride
-        (std::uint32_t)per_core_N,  // in1_tensor_next_w_dim_block_stride
+        (std::uint32_t)in1_block_w,  // in1_tensor_next_w_dim_block_stride
         // in1 block args
-        (std::uint32_t)per_core_N,  // in1_block_w
-        (std::uint32_t)in0_block_w,  // in1_block_h
-        (std::uint32_t)per_core_N * in0_block_w,  // in1_block_num_tiles
+        (std::uint32_t)in1_block_w,  // in1_block_w
+        (std::uint32_t)in0_block_w,  // in1_block_h
+        (std::uint32_t)in1_block_w * in0_block_w,  // in1_block_num_tiles
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // out_num_blocks_x
-        (std::uint32_t)1,  // out_num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // out_num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // out_num_blocks_y
         // in1 mcast args
         (std::uint32_t)0,
         (std::uint32_t)0,
@@ -351,8 +367,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         (std::uint32_t)N,  // out_tensor_stride_h
         (std::uint32_t)out_subblock_w,  // out_tensor_next_subblock_stride_w
         (std::uint32_t)out_subblock_h * N,  // out_tensor_next_subblock_stride_h
-        (std::uint32_t)per_core_N,  // out_tensor_next_w_dim_block_stride
-        (std::uint32_t)per_core_M * N,  // out_tensor_next_h_dim_block_stride
+        (std::uint32_t)out_block_w,  // out_tensor_next_w_dim_block_stride
+        (std::uint32_t)out_block_h * N,  // out_tensor_next_h_dim_block_stride
         // out subblock args
         (std::uint32_t)out_subblock_w,  // out_subblock_w
         (std::uint32_t)out_subblock_h,  // out_subblock_h
@@ -372,11 +388,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
 
     std::vector<uint32_t> in0_receiver_compile_time_args = {
         // in0 block args
-        (std::uint32_t)in0_block_w * per_core_M,  // in0_block_num_tiles
+        (std::uint32_t)in0_block_num_tiles,  // in0_block_num_tiles
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // out_num_blocks_x
-        (std::uint32_t)1,  // out_num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // out_num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // out_num_blocks_y
         // in0 mcast args
         (std::uint32_t)in0_mcast_sender_semaphore_id,
         (std::uint32_t)in0_mcast_receiver_semaphore_id,
@@ -518,7 +534,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
 
     uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w;
 
-    uint32_t in1_num_subblocks = (per_core_N / out_subblock_w);
+    uint32_t in1_num_subblocks = (out_block_w / out_subblock_w);
     uint32_t in1_block_num_tiles = out_subblock_w * in0_block_w * in1_num_subblocks;
     uint32_t in1_per_core_w = out_subblock_w * in1_num_subblocks;
 
@@ -534,9 +550,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         in1_block_num_tiles,  // in1_block_num_tiles
         in1_per_core_w,  // in1_per_core_w
 
-        num_blocks,  // num_blocks
-        1,  // out_num_blocks_x
-        1,  // out_num_blocks_y
+        num_blocks,  // num_blocks
+        out_num_blocks_x,  // out_num_blocks_x
+        out_num_blocks_y,  // out_num_blocks_y
 
         out_subblock_h,  // out_subblock_h
         out_subblock_w,  // out_subblock_w
@@ -627,7 +643,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     tt_metal::CircularBufferConfig output_cb_config =
         tt_metal::CircularBufferConfig(0, {{output_cb_index, output_data_format}});
 
-    if ((interm0_data_format != output_data_format) || (untilize_out && (in1_num_subblocks > 1))) {
+    if (do_not_inplace_interm0_out_CB || (interm0_data_format != output_data_format) ||
+        (untilize_out && (in1_num_subblocks > 1))) {
         // output
         std::map<uint8_t, tt::DataFormat> output_cb_data_format_spec{
             {output_cb_index, output_data_format},
@@ -697,20 +714,22 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     }
 
     // Parameters for last row, col, or block
-    uint32_t last_block_h = M % per_core_M == 0 ? per_core_M : M % per_core_M;
-    uint32_t last_block_w = N % per_core_N == 0 ? per_core_N : N % per_core_N;
-    uint32_t last_block_num_nonzero_subblocks_h = (last_block_h - 1) / out_subblock_h + 1;
-    uint32_t last_block_num_nonzero_subblocks_w = (last_block_w - 1) / out_subblock_w + 1;
+    uint32_t last_per_core_M = M % per_core_M == 0 ? per_core_M : M % per_core_M;
+    uint32_t last_per_core_N = N % per_core_N == 0 ? per_core_N : N % per_core_N;
+    uint32_t last_out_block_h = last_per_core_M % out_block_h == 0 ? out_block_h : last_per_core_M % out_block_h;
+    uint32_t last_out_block_w = last_per_core_N % out_block_w == 0 ? out_block_w : last_per_core_N % out_block_w;
+    uint32_t last_block_num_nonzero_subblocks_h = (last_out_block_h - 1) / out_subblock_h + 1;
+    uint32_t last_block_num_nonzero_subblocks_w = (last_out_block_w - 1) / out_subblock_w + 1;
     uint32_t last_subblock_of_last_block_h =
-        last_block_h % out_subblock_h == 0 ? out_subblock_h : last_block_h % out_subblock_h;
+        last_out_block_h % out_subblock_h == 0 ? out_subblock_h : last_out_block_h % out_subblock_h;
     uint32_t last_subblock_of_last_block_w =
-        last_block_w % out_subblock_w == 0 ? out_subblock_w : last_block_w % out_subblock_w;
+        last_out_block_w % out_subblock_w == 0 ? out_subblock_w : last_out_block_w % out_subblock_w;
     uint32_t last_block_padded_subblock_tiles_addr_skip =
         output_single_tile_size * (out_subblock_w - last_subblock_of_last_block_w);
     uint32_t last_block_padded_block_tiles_w_skip =
-        (out_subblock_w * out_subblock_h) * (per_core_N / out_subblock_w - last_block_num_nonzero_subblocks_w);
+        (out_subblock_w * out_subblock_h) * (out_block_w / out_subblock_w - last_block_num_nonzero_subblocks_w);
     uint32_t last_block_padded_block_tiles_h_skip =
-        (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h);
+        (out_block_h / out_subblock_h - last_block_num_nonzero_subblocks_h) * (out_block_w * out_subblock_h);
 
     CoreCoord start_core_noc = top_left_core_physical;
     CoreCoord end_core_noc = bottom_right_core_physical;
@@ -772,7 +791,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
             (std::uint32_t)end_core_noc.y,  // in0_mcast_dest_noc_end_y
 
             // padding args
-            (std::uint32_t)per_core_M  // last_block_h
+            (std::uint32_t)out_block_h  // last_block_h
         };
 
         if (fuse_op) {
@@ -815,27 +834,27 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         if (output_idx_x == num_blocks_x - 1) {
             // padding args (READER)
-            mm_in1_sender_writer_args.push_back(last_block_w);
+            mm_in1_sender_writer_args.push_back(last_out_block_w);
 
             // padding args (WRITER)
-            mm_in1_sender_writer_args.push_back(per_core_M / out_subblock_h);
+            mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
             mm_in1_sender_writer_args.push_back(out_subblock_h);
             mm_in1_sender_writer_args.push_back(0);
-            mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w);  // out_num_nonzero_subblocks_w
+            mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w);  // out_num_nonzero_subblocks_w
             mm_in1_sender_writer_args.push_back(last_block_num_nonzero_subblocks_w);
             mm_in1_sender_writer_args.push_back(last_subblock_of_last_block_w);
             mm_in1_sender_writer_args.push_back(last_block_padded_subblock_tiles_addr_skip);
             mm_in1_sender_writer_args.push_back(last_block_padded_block_tiles_w_skip);
         } else {
             // padding args (READER)
-            mm_in1_sender_writer_args.push_back(per_core_N);
+            mm_in1_sender_writer_args.push_back(out_block_w);
 
             // padding args (WRITER)
-            mm_in1_sender_writer_args.push_back(per_core_M / out_subblock_h);
+            mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
             mm_in1_sender_writer_args.push_back(out_subblock_h);
             mm_in1_sender_writer_args.push_back(0);
-            mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w);  // out_num_nonzero_subblocks_w
-            mm_in1_sender_writer_args.push_back(per_core_N / out_subblock_w);
+            mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w);  // out_num_nonzero_subblocks_w
+            mm_in1_sender_writer_args.push_back(out_block_w / out_subblock_w);
             mm_in1_sender_writer_args.push_back(out_subblock_w);
             mm_in1_sender_writer_args.push_back(0);
             mm_in1_sender_writer_args.push_back(0);
@@ -948,6 +967,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     uint32_t in0_block_w,
     uint32_t out_subblock_h,
     uint32_t out_subblock_w,
+    uint32_t out_block_h,
+    uint32_t out_block_w,
     uint32_t per_core_M,
     uint32_t per_core_N,
     std::optional<UnaryWithParam> fused_activation,
@@ -991,10 +1012,19 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format);
     uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format);
 
-    uint32_t in0_block_tiles = per_core_M * in0_block_w;
+    bool do_not_inplace_interm0_out_CB = output_is_sharded && (per_core_M != out_block_h);
+
+    uint32_t in0_block_h = out_block_h;
+    uint32_t in1_block_w = out_block_w;
+    uint32_t in0_num_blocks_y = per_core_M / out_block_h;
+    uint32_t in1_num_blocks_x = per_core_N / out_block_w;
+    uint32_t out_num_blocks_x = in1_num_blocks_x;
+    uint32_t out_num_blocks_y = in0_num_blocks_y;
+
+    uint32_t in0_block_tiles = in0_block_h * in0_block_w;
     uint32_t in0_CB_tiles = in0_block_tiles;
     if (in0_is_sharded) {
-        in0_CB_tiles = num_blocks * in0_CB_tiles * B;
+        in0_CB_tiles = num_blocks * per_core_M * in0_block_w * B;
     } else if (B * num_blocks > 1) {
         in0_CB_tiles = in0_CB_tiles * 2;  // double buffer
     }
@@ -1015,20 +1045,27 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
             extract_shard_sub_blocks = true;
         }
     }
+    uint32_t in2_CB_tiles = in0_block_tiles;
+    uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size;
 
-    uint32_t in1_block_tiles = per_core_N * in0_block_w;
+    uint32_t in1_block_tiles = out_block_w * in0_block_w;
     uint32_t in1_CB_tiles = in1_block_tiles;
     if (B * num_blocks > 1) {
         in1_CB_tiles = in1_CB_tiles * 2;  // double buffer
     }
     uint32_t in1_CB_size = in1_CB_tiles * in1_single_tile_size;
 
-    uint32_t out_block_tiles = per_core_M * per_core_N;
+    uint32_t out_block_tiles = out_block_h * out_block_w;
+    uint32_t out_shard_tiles = per_core_M * per_core_N;
     uint32_t out_CB_tiles = out_block_tiles;  // No double buffer
+    if (output_is_sharded) {
+        out_CB_tiles = out_shard_tiles;
+    }
     uint32_t out_CB_size = out_CB_tiles * output_single_tile_size;
-    uint32_t interm0_CB_size = out_CB_tiles * interm0_single_tile_size;
+    uint32_t interm0_CB_tiles = out_block_tiles;  // No double buffer
+    uint32_t interm0_CB_size = interm0_CB_tiles * interm0_single_tile_size;
 
-    uint32_t in3_block_tiles = per_core_N;
+    uint32_t in3_block_tiles = out_block_w;
     uint32_t in3_CB_tiles = in3_block_tiles;  // No double buffer
     uint32_t in3_CB_size = in3_CB_tiles * bias_single_tile_size;
@@ -1078,21 +1115,21 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         (std::uint32_t)in0_is_dram,
 
         // in0 tensor args
-        (std::uint32_t)1,  // in0_tensor_stride_w
-        (std::uint32_t)K,  // in0_tensor_stride_h
-        (std::uint32_t)in0_block_w,  // in0_tensor_next_block_stride
-        (std::uint32_t)K * per_core_M,  // in0_tensor_next_h_dim_block_stride
+        (std::uint32_t)1,  // in0_tensor_stride_w
+        (std::uint32_t)K,  // in0_tensor_stride_h
+        (std::uint32_t)in0_block_w,  // in0_tensor_next_block_stride
+        (std::uint32_t)K * in0_block_h,  // in0_tensor_next_h_dim_block_stride
         // in0 block args
-        (std::uint32_t)in0_block_w,  // in0_block_w
-        (std::uint32_t)per_core_M,  // in0_block_h
-        (std::uint32_t)in0_block_w * per_core_M,  // in0_block_num_tiles
+        (std::uint32_t)in0_block_w,  // in0_block_w
+        (std::uint32_t)in0_block_h,  // in0_block_h
+        (std::uint32_t)in0_block_w * in0_block_h,  // in0_block_num_tiles
         (std::uint32_t)extract_shard_sub_blocks,
         (std::uint32_t)in0_shard_width_in_tiles,
         (std::uint32_t)in0_shard_height_in_tiles,
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // out_num_blocks_x
-        (std::uint32_t)1,  // out_num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // out_num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // out_num_blocks_y
         // in0 mcast args
         (std::uint32_t)0,
         (std::uint32_t)0,
@@ -1114,15 +1151,15 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         (std::uint32_t)1,  // in1_tensor_stride_w
         (std::uint32_t)N,  // in1_tensor_stride_h
         (std::uint32_t)in0_block_w * N,  // in1_tensor_next_block_stride
-        (std::uint32_t)per_core_N,  // in1_tensor_next_w_dim_block_stride
+        (std::uint32_t)in1_block_w,  // in1_tensor_next_w_dim_block_stride
         // in1 block args
-        (std::uint32_t)per_core_N,  // in1_block_w
-        (std::uint32_t)in0_block_w,  // in1_block_h
-        (std::uint32_t)per_core_N * in0_block_w,  // in1_block_num_tiles
+        (std::uint32_t)in1_block_w,  // in1_block_w
+        (std::uint32_t)in0_block_w,  // in1_block_h
+        (std::uint32_t)in1_block_w * in0_block_w,  // in1_block_num_tiles
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // out_num_blocks_x
-        (std::uint32_t)1,  // out_num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // out_num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // out_num_blocks_y
         // in1 mcast args
         (std::uint32_t)in1_mcast_sender_semaphore_id,
         (std::uint32_t)in1_mcast_receiver_semaphore_id,
@@ -1139,8 +1176,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         (std::uint32_t)N,  // out_tensor_stride_h
         (std::uint32_t)out_subblock_w,  // out_tensor_next_subblock_stride_w
         (std::uint32_t)out_subblock_h * N,  // out_tensor_next_subblock_stride_h
-        (std::uint32_t)per_core_N,  // out_tensor_next_w_dim_block_stride
-        (std::uint32_t)per_core_M * N,  // out_tensor_next_h_dim_block_stride
+        (std::uint32_t)out_block_w,  // out_tensor_next_w_dim_block_stride
+        (std::uint32_t)out_block_h * N,  // out_tensor_next_h_dim_block_stride
         // out subblock args
         (std::uint32_t)out_subblock_w,  // out_subblock_w
         (std::uint32_t)out_subblock_h,  // out_subblock_h
@@ -1164,11 +1201,11 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         // READER
         // in1 block args
-        (std::uint32_t)per_core_N * in0_block_w,  // in1_block_num_tiles
+        (std::uint32_t)in1_block_w * in0_block_w,  // in1_block_num_tiles
         // in0/in1 common args
-        (std::uint32_t)num_blocks,  // num_blocks
-        (std::uint32_t)1,  // out_num_blocks_x
-        (std::uint32_t)1,  // out_num_blocks_y
+        (std::uint32_t)num_blocks,  // num_blocks
+        (std::uint32_t)out_num_blocks_x,  // out_num_blocks_x
+        (std::uint32_t)out_num_blocks_y,  // out_num_blocks_y
         // in1 mcast args
         (std::uint32_t)in1_mcast_sender_semaphore_id,
         (std::uint32_t)in1_mcast_receiver_semaphore_id,
@@ -1181,8 +1218,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         (std::uint32_t)N,  // out_tensor_stride_h
         (std::uint32_t)out_subblock_w,  // out_tensor_next_subblock_stride_w
         (std::uint32_t)out_subblock_h * N,  // out_tensor_next_subblock_stride_h
-        (std::uint32_t)per_core_N,  // out_tensor_next_w_dim_block_stride
-        (std::uint32_t)per_core_M * N,  // out_tensor_next_h_dim_block_stride
+        (std::uint32_t)out_block_w,  // out_tensor_next_w_dim_block_stride
+        (std::uint32_t)out_block_h * N,  // out_tensor_next_h_dim_block_stride
         // out subblock args
         (std::uint32_t)out_subblock_w,  // out_subblock_w
         (std::uint32_t)out_subblock_h,  // out_subblock_h
@@ -1279,11 +1316,11 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
 
     // Compute kernel compile time args
-    uint32_t in0_num_subblocks = (per_core_M / out_subblock_h);
+    uint32_t in0_num_subblocks = (out_block_h / out_subblock_h);
     uint32_t in0_block_num_tiles = out_subblock_h * in0_block_w * in0_num_subblocks;
     uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w;
 
-    uint32_t in1_num_subblocks = (per_core_N / out_subblock_w);
+    uint32_t in1_num_subblocks = (out_block_w / out_subblock_w);
     uint32_t in1_block_num_tiles = out_subblock_w * in0_block_w * in1_num_subblocks;
     uint32_t in1_per_core_w = out_subblock_w * in1_num_subblocks;
@@ -1299,9 +1336,9 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         in1_block_num_tiles,  // in1_block_num_tiles
         in1_per_core_w,  // in1_per_core_w
 
-        num_blocks,  // num_blocks
-        1,  // out_num_blocks_x
-        1,  // out_num_blocks_y
+        num_blocks,  // num_blocks
+        out_num_blocks_x,  // out_num_blocks_x
+        out_num_blocks_y,  // out_num_blocks_y
 
         out_subblock_h,  // out_subblock_h
         out_subblock_w,  // out_subblock_w
@@ -1349,7 +1386,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     CBHandle cb_src2 = 0;
     if (in0_is_sharded and extract_shard_sub_blocks) {  // in0_is_sharded is technically redundant
         tt_metal::CircularBufferConfig src2_cb_config =
-            tt_metal::CircularBufferConfig(in0_CB_size, {{src2_cb_index, in0_data_format}})
+            tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}})
                 .set_page_size(src2_cb_index, in0_single_tile_size)
                 .set_globally_allocated_address(*in0_buffer)
                 .set_tile_dims(src2_cb_index, in0_tile);
@@ -1359,8 +1396,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
             "CB {} :: PS = {}, NP = {}, TOTAL = {}",
             src2_cb_index,
             in0_single_tile_size,
-            in0_CB_size / in0_single_tile_size,
-            in0_CB_size);
+            in2_CB_size / in0_single_tile_size,
+            in2_CB_size);
     }
 
     uint32_t src1_cb_index = 1;
@@ -1384,7 +1421,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     tt_metal::CircularBufferConfig output_cb_config =
         tt_metal::CircularBufferConfig(0, {{output_cb_index, output_data_format}});
 
-    if (interm0_data_format != output_data_format) {
+    if (do_not_inplace_interm0_out_CB || (interm0_data_format != output_data_format) ||
+        (untilize_out && (in1_num_subblocks > 1))) {
         // output
         std::map<uint8_t, tt::DataFormat> output_cb_data_format_spec{
             {output_cb_index, output_data_format},
@@ -1448,20 +1486,22 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     }
 
     // Parameters for last row, col, or block
-    uint32_t last_block_h = M % per_core_M == 0 ? per_core_M : M % per_core_M;
-    uint32_t last_block_w = N % per_core_N == 0 ? per_core_N : N % per_core_N;
-    uint32_t last_block_num_nonzero_subblocks_h = (last_block_h - 1) / out_subblock_h + 1;
-    uint32_t last_block_num_nonzero_subblocks_w = (last_block_w - 1) / out_subblock_w + 1;
+    uint32_t last_per_core_M = M % per_core_M == 0 ? per_core_M : M % per_core_M;
+    uint32_t last_per_core_N = N % per_core_N == 0 ? per_core_N : N % per_core_N;
+    uint32_t last_out_block_h = last_per_core_M % out_block_h == 0 ? out_block_h : last_per_core_M % out_block_h;
+    uint32_t last_out_block_w = last_per_core_N % out_block_w == 0 ? out_block_w : last_per_core_N % out_block_w;
+    uint32_t last_block_num_nonzero_subblocks_h = (last_out_block_h - 1) / out_subblock_h + 1;
+    uint32_t last_block_num_nonzero_subblocks_w = (last_out_block_w - 1) / out_subblock_w + 1;
     uint32_t last_subblock_of_last_block_h =
-        last_block_h % out_subblock_h == 0 ? out_subblock_h : last_block_h % out_subblock_h;
+        last_out_block_h % out_subblock_h == 0 ? out_subblock_h : last_out_block_h % out_subblock_h;
     uint32_t last_subblock_of_last_block_w =
-        last_block_w % out_subblock_w == 0 ? out_subblock_w : last_block_w % out_subblock_w;
+        last_out_block_w % out_subblock_w == 0 ? out_subblock_w : last_out_block_w % out_subblock_w;
     uint32_t last_block_padded_subblock_tiles_addr_skip =
         output_single_tile_size * (out_subblock_w - last_subblock_of_last_block_w);
     uint32_t last_block_padded_block_tiles_w_skip =
-        (out_subblock_w * out_subblock_h) * (per_core_N / out_subblock_w - last_block_num_nonzero_subblocks_w);
+        (out_subblock_w * out_subblock_h) * (out_block_w / out_subblock_w - last_block_num_nonzero_subblocks_w);
     uint32_t last_block_padded_block_tiles_h_skip =
-        (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h);
+        (out_block_h / out_subblock_h - last_block_num_nonzero_subblocks_h) * (out_block_w * out_subblock_h);
 
     CoreCoord start_core_noc = bottom_right_core_physical;
     CoreCoord end_core_noc = top_left_core_physical;
@@ -1494,13 +1534,13 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
             (std::uint32_t)output_idx_x * per_core_N + output_idx_y * per_core_M * N,  // out_tensor_start_tile_id
 
             // padding args (READER)
-            (std::uint32_t)per_core_N,  // last_block_w
+            (std::uint32_t)out_block_w,  // last_block_w
             // padding args (WRITER)
-            (std::uint32_t)per_core_M / out_subblock_h,
+            (std::uint32_t)out_block_h / out_subblock_h,
             (std::uint32_t)out_subblock_h,
             (std::uint32_t)0,
-            (std::uint32_t)per_core_N / out_subblock_w,
-            (std::uint32_t)per_core_N / out_subblock_w,
+            (std::uint32_t)out_block_w / out_subblock_w,
+            (std::uint32_t)out_block_w / out_subblock_w,
             (std::uint32_t)out_subblock_w,
             (std::uint32_t)0,
             (std::uint32_t)0};
@@ -1530,23 +1570,23 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
 
             if (output_idx_y == num_blocks_y - 1) {
                 // padding args (WRITER)
-                mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h);
+                mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
                 mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
                 mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
                 mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_h_skip);
-                mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w);
-                mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w);
+                mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w);
+                mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w);
                 mm_in1_receiver_writer_args.push_back(out_subblock_w);
                 mm_in1_receiver_writer_args.push_back(0);
                 mm_in1_receiver_writer_args.push_back(0);
             } else {
                 // padding args (WRITER)
-                mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h);
-                mm_in1_receiver_writer_args.push_back(per_core_M / out_subblock_h);
+                mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
+                mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
                 mm_in1_receiver_writer_args.push_back(out_subblock_h);
                 mm_in1_receiver_writer_args.push_back(0);
-                mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w);
-                mm_in1_receiver_writer_args.push_back(per_core_N / out_subblock_w);
+                mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w);
+                mm_in1_receiver_writer_args.push_back(out_block_w / out_subblock_w);
                 mm_in1_receiver_writer_args.push_back(out_subblock_w);
                 mm_in1_receiver_writer_args.push_back(0);
                 mm_in1_receiver_writer_args.push_back(0);
@@ -1973,6 +2013,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
     uint32_t in0_block_w,
     uint32_t out_subblock_h,
     uint32_t out_subblock_w,
+    uint32_t out_block_h,
+    uint32_t out_block_w,
     uint32_t per_core_M,
     uint32_t per_core_N,
     bool fuse_batch,
@@ -2131,6 +2173,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
             in0_block_w,
             out_subblock_h,
             out_subblock_w,
+            out_block_h,
+            out_block_w,
             per_core_M,
             per_core_N,
             fused_activation,
@@ -2168,6 +2212,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
             in0_block_w,
             out_subblock_h,
             out_subblock_w,
+            out_block_h,
+            out_block_w,
             per_core_M,
             per_core_N,
             fused_activation,
@@ -2200,6 +2246,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(
     uint32_t in0_block_w,
     uint32_t out_subblock_h,
     uint32_t out_subblock_w,
+    uint32_t out_block_h,
+    uint32_t out_block_w,
     uint32_t per_core_M,
     uint32_t per_core_N,
     bool fuse_batch,
@@ -2222,6 +2270,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(
         in0_block_w,
         out_subblock_h,
         out_subblock_w,
+        out_block_h,
+        out_block_w,
         per_core_M,
         per_core_N,
         fuse_batch,
@@ -2258,6 +2308,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_helper(
         config.in0_block_w,
         config.out_subblock_h,
         config.out_subblock_w,
+        config.out_block_h,
+        config.out_block_w,
         config.per_core_M,
         config.per_core_N,
         config.fuse_batch,
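
Threaded through both factory variants, out_num_blocks_x/out_num_blocks_y turn the previously single-block kernel invocations into a loop over output blocks. In outline, the iteration order the dataflow kernels follow (compare the reader kernel earlier in this diff; Python pseudocode, not repo code):

    for b in range(batch):
        for bh in range(out_num_blocks_y):      # per_core_M // out_block_h
            for bw in range(out_num_blocks_x):  # per_core_N // out_block_w
                for block in range(num_blocks_inner_dim):
                    pass  # stream one in0_block_h x in0_block_w chunk of in0 (and the matching in1 block)
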
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp
index d6a9b7b2bb6..46ac80aa248 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp
@@ -129,7 +129,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
     if (in0_is_sharded) {
         in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1];
         in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0];
-        in2_block_tiles = out_block_h * in0_shard_width_in_tiles;
+        in2_block_tiles = per_core_M * in0_shard_width_in_tiles;
     }
 
     uint32_t in2_CB_tiles = in2_block_tiles;
@@ -363,12 +363,12 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
         (std::uint32_t)in0_block_w,  // in0_tensor_next_inner_dim_block_stride
         (std::uint32_t)K * in0_block_h,  // in0_tensor_next_h_dim_block_stride
         // in0 block args
-        (std::uint32_t)in0_block_w,  // in0_block_w
-        (std::uint32_t)in0_block_h,  // in0_block_h
-        (std::uint32_t)in0_block_num_tiles,  // in0_block_num_tiles
-        (std::uint32_t) false,  // extract_shard_sub_blocks (not used for interleaved)
-        (std::uint32_t)0,  // shard_width_in_tiles (not used for interleaved)
-        (std::uint32_t)0,  // shard_height_in_tiles (not used for interleaved)
+        (std::uint32_t)in0_block_w,  // in0_block_w
+        (std::uint32_t)in0_block_h,  // in0_block_h
+        (std::uint32_t)in0_block_num_tiles,  // in0_block_num_tiles
+        (std::uint32_t)false,  // extract_shard_sub_blocks (not used for interleaved)
+        (std::uint32_t)in0_shard_width_in_tiles,  // shard_width_in_tiles (not used for interleaved)
+        (std::uint32_t)in0_shard_height_in_tiles,  // shard_height_in_tiles (not used for interleaved)
         // in0/in1 common args
         (std::uint32_t)num_blocks,  // num_blocks
         (std::uint32_t)out_num_blocks_x,
diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
index 7619134656e..de6d7348eb2 100644
--- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -119,22 +119,43 @@ void py_module(py::module& module) {
     matmul_multi_core_reuse_multicast_1d_program_config
         .def(
-            py::init<
-                CoreCoord,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                std::size_t,
-                bool,
-                std::optional<UnaryWithParam>,
-                bool,
-                bool>(),
+            py::init([](CoreCoord compute_with_storage_grid_size,
+                        std::size_t in0_block_w,
+                        std::size_t out_subblock_h,
+                        std::size_t out_subblock_w,
+                        std::optional<std::size_t> out_block_h,
+                        std::optional<std::size_t> out_block_w,
+                        std::size_t per_core_M,
+                        std::size_t per_core_N,
+                        bool fuse_batch,
+                        std::optional<UnaryWithParam> fused_activation,
+                        bool mcast_in0,
+                        bool gather_in0) {
+                // Set out_block_h and out_block_w to defaults if they are not provided
+                std::size_t actual_out_block_h = out_block_h.value_or(per_core_M);
+                std::size_t actual_out_block_w = out_block_w.value_or(per_core_N);
+
+                return MatmulMultiCoreReuseMultiCast1DProgramConfig(
+                    compute_with_storage_grid_size,
+                    in0_block_w,
+                    out_subblock_h,
+                    out_subblock_w,
+                    actual_out_block_h,
+                    actual_out_block_w,
+                    per_core_M,
+                    per_core_N,
+                    fuse_batch,
+                    std::move(fused_activation),
+                    mcast_in0,
+                    gather_in0);
+            }),
             py::kw_only(),
             py::arg("compute_with_storage_grid_size"),
             py::arg("in0_block_w").noconvert(),
             py::arg("out_subblock_h").noconvert(),
             py::arg("out_subblock_w").noconvert(),
+            py::arg("out_block_h") = py::none(),
+            py::arg("out_block_w") = py::none(),
             py::arg("per_core_M").noconvert(),
             py::arg("per_core_N").noconvert(),
             py::arg("fuse_batch").noconvert(),
@@ -147,6 +168,8 @@ void py_module(py::module& module) {
         .def_readwrite("in0_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::in0_block_w)
         .def_readwrite("out_subblock_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_h)
         .def_readwrite("out_subblock_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_subblock_w)
+        .def_readwrite("out_block_h", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_block_h)
+        .def_readwrite("out_block_w", &MatmulMultiCoreReuseMultiCast1DProgramConfig::out_block_w)
         .def_readwrite("per_core_M", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_M)
         .def_readwrite("per_core_N", &MatmulMultiCoreReuseMultiCast1DProgramConfig::per_core_N)
         .def_readwrite("fuse_batch", &MatmulMultiCoreReuseMultiCast1DProgramConfig::fuse_batch)
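
On the Python side, the two new fields are optional keyword arguments; omitting them reproduces the old single-block behavior because they default to per_core_M/per_core_N. A usage sketch (values illustrative, mirroring the test earlier in this diff):

    program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
        compute_with_storage_grid_size=(8, 2),
        in0_block_w=2,
        out_subblock_h=1,
        out_subblock_w=1,
        out_block_h=4,  # per_core_M=8 split into 2 output blocks along M
        out_block_w=2,  # per_core_N=4 split into 2 output blocks along N
        per_core_M=8,
        per_core_N=4,
        fuse_batch=True,
        fused_activation=None,
        mcast_in0=True,
    )
    # Leaving out out_block_h/out_block_w is equivalent to
    # out_block_h=per_core_M, out_block_w=per_core_N.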