From bc19f9c1ddc922e8fcdd7c8b7e17868d79773ea2 Mon Sep 17 00:00:00 2001
From: Brian Liu
Date: Tue, 28 May 2024 20:26:28 +0000
Subject: [PATCH] #0: Properly support trivial single core case for 1D matmuls

- For mcast in1, only mcast to cores with work (similar to mcast in0)
- For single core, skip receiver kernel setup (sketched after the diff)
- TODO: For sharded in0, K must be divisible by in0_block_w due to two separate bugs:
  * For mcast in0, the sharded reader doesn't support turning off mcast in the single core case
  ** The bug: we still issue a regular mcast, which multicasts to 0 cores
  * For mcast in1, the sharded reader doesn't support slicing along the shard width
---
 .../short/matmul_user_program_config.py       |  58 +++++
 ...op_multi_core_reuse_mcast_1d_optimized.cpp | 201 +++++++++---------
 2 files changed, 161 insertions(+), 98 deletions(-)

diff --git a/tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config.py b/tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config.py
index d2d115411a5..b6f59d62a60 100644
--- a/tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config.py
+++ b/tests/ttnn/sweep_tests/sweeps/sweeps/matmul/short/matmul_user_program_config.py
@@ -194,6 +194,64 @@ class TensorMemoryConfigs(enum.Enum):
         ttnn.L1_MEMORY_CONFIG,
         ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
     ),
+    # Matmul 1D mcast in0 (single core)
+    (
+        (1,),
+        (64, 64, 128),
+        False,
+        ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
+            compute_with_storage_grid_size=(1, 1),
+            in0_block_w=2,
+            out_subblock_h=1,
+            out_subblock_w=1,
+            per_core_M=2,
+            per_core_N=4,
+            fuse_batch=True,
+            fused_activation=None,
+            mcast_in0=True,
+        ),
+        ttnn.MemoryConfig(
+            memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            buffer_type=ttnn.BufferType.L1,
+            shard_spec=ttnn.ShardSpec(
+                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}),
+                (64, 64),
+                ttnn.ShardOrientation.ROW_MAJOR,
+                False,
+            ),
+        ),
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+    ),
+    # Matmul 1D mcast in1 (single core)
+    (
+        (1,),
+        (64, 64, 128),
+        False,
+        ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
+            compute_with_storage_grid_size=(1, 1),
+            in0_block_w=2,
+            out_subblock_h=1,
+            out_subblock_w=1,
+            per_core_M=2,
+            per_core_N=4,
+            fuse_batch=True,
+            fused_activation=None,
+            mcast_in0=False,
+        ),
+        ttnn.MemoryConfig(
+            memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            buffer_type=ttnn.BufferType.L1,
+            shard_spec=ttnn.ShardSpec(
+                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}),
+                (64, 64),
+                ttnn.ShardOrientation.ROW_MAJOR,
+                False,
+            ),
+        ),
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
+    ),
     # Matmul 2D mcast
     (
         (1,),
diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
index 59b092414b7..8d97bbac7ae 100644
--- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
+++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp
@@ -135,7 +135,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     CoreRangeSet in0_mcast_cores_with_work_and_in_receiver_grid({});
     CoreRangeSet in0_mcast_cores_without_work_and_in_receiver_grid({});
     CoreRangeSet in0_mcast_cores_without_work_and_not_in_receiver_grid({});
-    CoreRangeSet mcast_receivers({});
+    CoreRangeSet in0_mcast_receivers({});
 
     if (in0_is_sharded) {
         in0_mcast_cores_with_work_and_in_receiver_grid = all_cores_with_work;
@@ -166,11 +166,13 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         }
     } else {
         in0_mcast_cores_with_work_and_in_receiver_grid = CoreRangeSet({CoreRange(start_core, start_core)});
-        auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
-                                       ? CoreCoord{start_core.x + 1, start_core.y}
-                                       : CoreCoord{start_core.x, start_core.y + 1};
-        mcast_receivers =
-            num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+        if (in0_mcast_receiver_num_cores > 1) {
+            auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
+                                           ? CoreCoord{start_core.x + 1, start_core.y}
+                                           : CoreCoord{start_core.x, start_core.y + 1};
+            in0_mcast_receivers = num_cores_to_corerange_set(
+                receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+        }
     }
 
     // Mcast args
@@ -379,11 +381,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     }
 
     KernelHandle mm_kernel_in0_receiver_id = 0;
-    if (!in0_is_sharded and mcast_receivers.num_cores() > 0) {
+    if (!in0_is_sharded and in0_mcast_receivers.num_cores() > 0) {
         mm_kernel_in0_receiver_id = tt_metal::CreateKernel(
             program,
             "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
-            mcast_receivers,
+            in0_mcast_receivers,
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_1,
                 .noc = in0_noc,
@@ -855,26 +857,27 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     CoreCoord start_core = {0, 0};
     uint32_t start_core_x = start_core.x;
     uint32_t start_core_y = start_core.y;
-    uint32_t num_cores_c = compute_with_storage_grid_size.x;
-    uint32_t num_cores_r = compute_with_storage_grid_size.y;
 
     uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
     uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
     uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
     uint32_t num_cores = num_blocks_total;
-    uint32_t mcast_num_cores = num_cores_c * num_cores_r;  // Exclude Sender
 
     constexpr bool row_major = true;
     CoreRangeSet all_cores =
         num_cores_to_corerange_set(start_core, num_cores, compute_with_storage_grid_size, row_major);
+    CoreRange in1_mcast_receiver_cores_bounding_box = all_cores.bounding_box();
+    uint32_t in1_mcast_receiver_num_cores = in1_mcast_receiver_cores_bounding_box.size();  // always mcast to full grid
 
-    CoreRange mcast_sender(start_core, start_core);
-
-    auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
-                                   ? CoreCoord{start_core.x + 1, start_core.y}
-                                   : CoreCoord{start_core.x, start_core.y + 1};
-    CoreRangeSet mcast_receivers =
-        num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+    CoreRange in1_mcast_sender(start_core, start_core);
+    CoreRangeSet in1_mcast_receivers({});
+    if (in1_mcast_receiver_num_cores > 1) {
+        auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
+                                       ? CoreCoord{start_core.x + 1, start_core.y}
+                                       : CoreCoord{start_core.x, start_core.y + 1};
+        in1_mcast_receivers =
+            num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+    }
 
     // Mcast args
     auto in1_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
@@ -882,9 +885,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     uint32_t in3_mcast_sender_semaphore = 0;
     uint32_t in3_mcast_receiver_semaphore = 0;
 
-    CoreCoord top_left_core = {(std::size_t)start_core_x, (std::size_t)start_core_y};
-    CoreCoord bottom_right_core = {
-        (std::size_t)start_core_x + num_cores_c - 1, (std::size_t)start_core_y + num_cores_r - 1};
+    CoreCoord top_left_core = in1_mcast_receiver_cores_bounding_box.start;
+    CoreCoord bottom_right_core = in1_mcast_receiver_cores_bounding_box.end;
     auto top_left_core_physical = device->worker_core_from_logical_core(top_left_core);
     auto bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core);
 
@@ -937,8 +939,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         // in1 mcast args
         (std::uint32_t)in1_mcast_sender_semaphore,
         (std::uint32_t)in1_mcast_receiver_semaphore,
-        (std::uint32_t)num_cores - 1,  // in1_mcast_num_dests
-        (std::uint32_t)mcast_num_cores - 1,  // in1_mcast_num_cores
+        (std::uint32_t)num_cores - 1,                     // in1_mcast_num_dests
+        (std::uint32_t)in1_mcast_receiver_num_cores - 1,  // in1_mcast_num_cores
         // batch args
         (std::uint32_t)K * N,  // KtNt
         (std::uint32_t)B,      // batch
@@ -1027,6 +1029,10 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
 
     mm_kernel_in0_sender_defines["SKIP_MCAST"] = "1";
 
+    if (in1_mcast_receiver_num_cores == 1) {
+        mm_kernel_in1_sender_writer_defines["SKIP_MCAST"] = "1";
+    }
+
     // in1 is the reader of weights/output writer, and we choose to make it use the optimized reader noc
     tt_metal::NOC in0_noc = detail::GetPreferredNOCForDRAMWrite(device->arch());
     tt_metal::NOC in1_noc = detail::GetPreferredNOCForDRAMRead(device->arch());
@@ -1044,23 +1050,26 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
         "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
-        mcast_sender,
+        in1_mcast_sender,
         tt_metal::DataMovementConfig{
             .processor = tt_metal::DataMovementProcessor::RISCV_0,
             .noc = in1_noc,
            .compile_args = in1_sender_writer_compile_time_args,
            .defines = mm_kernel_in1_sender_writer_defines});
 
-    auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
-        program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
-        "reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
-        mcast_receivers,
-        tt_metal::DataMovementConfig{
-            .processor = tt_metal::DataMovementProcessor::RISCV_0,
-            .noc = in1_noc,
-            .compile_args = in1_receiver_writer_compile_time_args,
-            .defines = mm_kernel_in1_receiver_writer_defines});
+    KernelHandle mm_kernel_in1_receiver_writer_id = 0;
+    if (in1_mcast_receivers.num_cores() > 0) {
+        mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
+            program,
+            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
+            in1_mcast_receivers,
+            tt_metal::DataMovementConfig{
+                .processor = tt_metal::DataMovementProcessor::RISCV_0,
+                .noc = in1_noc,
+                .compile_args = in1_receiver_writer_compile_time_args,
+                .defines = mm_kernel_in1_receiver_writer_defines});
+    }
 
     // Compute kernel compile time args
@@ -1541,71 +1550,67 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
     ////////////////////////////////////////////////////////////////////////////
     //                      Application Setup
     ////////////////////////////////////////////////////////////////////////////
-    if (compute_with_storage_grid_size.x > 1 or compute_with_storage_grid_size.y > 1) {
-        if (mcast_in0) {
-            return reuse_mcast_1d_optimized_helpers::create_program_mcast_in0(
-                a,
-                device,
-                math_fidelity,
-                fp32_dest_acc_en,
-                math_approx_mode,
-                packer_l1_acc,
-                compute_with_storage_grid_size,
-                B,
-                Mt,
-                Nt,
-                Kt,
-                bcast_batch,
-                in0_block_w,
-                out_subblock_h,
-                out_subblock_w,
-                per_core_M,
-                per_core_N,
-                fused_activation,
-                in0_buffer,
-                in1_buffer,
-                bias_buffer,
-                out_buffer,
-                in0_data_format,
-                in1_data_format,
-                bias_data_format,
-                output_data_format,
-                a.memory_config().is_sharded(),
-                output.memory_config().is_sharded(),
-                untilize_out);
-        } else {
-            return reuse_mcast_1d_optimized_helpers::create_program_mcast_in1(
-                device,
-                math_fidelity,
-                fp32_dest_acc_en,
-                math_approx_mode,
-                packer_l1_acc,
-                compute_with_storage_grid_size,
-                B,
-                Mt,
-                Nt,
-                Kt,
-                bcast_batch,
-                in0_block_w,
-                out_subblock_h,
-                out_subblock_w,
-                per_core_M,
-                per_core_N,
-                fused_activation,
-                in0_buffer,
-                in1_buffer,
-                bias_buffer,
-                out_buffer,
-                in0_data_format,
-                in1_data_format,
-                bias_data_format,
-                output_data_format,
-                a.memory_config().is_sharded(),
-                output.memory_config().is_sharded(),
-                untilize_out);
-        }
+    if (mcast_in0) {
+        return reuse_mcast_1d_optimized_helpers::create_program_mcast_in0(
+            a,
+            device,
+            math_fidelity,
+            fp32_dest_acc_en,
+            math_approx_mode,
+            packer_l1_acc,
+            compute_with_storage_grid_size,
+            B,
+            Mt,
+            Nt,
+            Kt,
+            bcast_batch,
+            in0_block_w,
+            out_subblock_h,
+            out_subblock_w,
+            per_core_M,
+            per_core_N,
+            fused_activation,
+            in0_buffer,
+            in1_buffer,
+            bias_buffer,
+            out_buffer,
+            in0_data_format,
+            in1_data_format,
+            bias_data_format,
+            output_data_format,
+            a.memory_config().is_sharded(),
+            output.memory_config().is_sharded(),
+            untilize_out);
     } else {
-        TT_FATAL(false, "Grid is invalid for mcast matmul!");
+        return reuse_mcast_1d_optimized_helpers::create_program_mcast_in1(
+            device,
+            math_fidelity,
+            fp32_dest_acc_en,
+            math_approx_mode,
+            packer_l1_acc,
+            compute_with_storage_grid_size,
+            B,
+            Mt,
+            Nt,
+            Kt,
+            bcast_batch,
+            in0_block_w,
+            out_subblock_h,
+            out_subblock_w,
+            per_core_M,
+            per_core_N,
+            fused_activation,
+            in0_buffer,
+            in1_buffer,
+            bias_buffer,
+            out_buffer,
+            in0_data_format,
+            in1_data_format,
+            bias_data_format,
+            output_data_format,
+            a.memory_config().is_sharded(),
+            output.memory_config().is_sharded(),
+            untilize_out);
     }
 }
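
Note (editorial, not part of the patch): the sketch below is a minimal standalone model, under simplified assumptions, of the receiver-range logic that both create_program_mcast_in0 and create_program_mcast_in1 now guard. With num_cores == 1 there are no receivers, so the receiver kernel is never created and the sender is compiled with SKIP_MCAST. CoreCoord and cores_row_major here are simplified stand-ins for the real tt-metal types and for num_cores_to_corerange_set; they are not the library's API.

// Standalone sketch (not tt-metal code) of the guarded receiver-range logic.
#include <cstdint>
#include <iostream>
#include <vector>

struct CoreCoord {
    std::size_t x;
    std::size_t y;
};

// Collect `n` cores row-major, starting at `start`, wrapping at grid.x.
// Simplified stand-in for num_cores_to_corerange_set(..., row_major = true).
static std::vector<CoreCoord> cores_row_major(CoreCoord start, uint32_t n, CoreCoord grid) {
    std::vector<CoreCoord> cores;
    CoreCoord c = start;
    for (uint32_t i = 0; i < n; ++i) {
        cores.push_back(c);
        if (++c.x == grid.x) {  // wrap to the start of the next row
            c.x = 0;
            ++c.y;
        }
    }
    return cores;
}

int main() {
    const CoreCoord grid{1, 1};        // compute_with_storage_grid_size = (1, 1)
    const uint32_t num_cores = 1;      // the trivial single-core case
    const CoreCoord start_core{0, 0};  // sender

    // Pre-patch, the receiver range was built unconditionally, asking for
    // num_cores - 1 == 0 receivers; in1 then still issued a multicast with
    // zero destinations. Post-patch, the range is only built when there is
    // at least one receiver:
    std::vector<CoreCoord> receivers;
    if (num_cores > 1) {
        // First receiver is the core to the sender's right, or the first core
        // of the next row when the sender sits on the grid's right edge.
        const CoreCoord receiver_start = start_core.x != grid.x - 1
                                             ? CoreCoord{start_core.x + 1, start_core.y}
                                             : CoreCoord{start_core.x, start_core.y + 1};
        receivers = cores_row_major(receiver_start, num_cores - 1, grid);
    }

    // An empty receiver set means: skip CreateKernel for the receiver kernel
    // and compile the sender with SKIP_MCAST, as the patch does.
    std::cout << "receivers: " << receivers.size() << '\n';
    return 0;
}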
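A second standalone sketch models the in1 destination-count change: the mcast extent now comes from the bounding box of the cores that actually have work instead of the whole compute grid. The bounding-box arithmetic below assumes a row-major fill from (0, 0) and stands in for all_cores.bounding_box().size(); the grid size and core count are hypothetical example numbers.

// Standalone sketch (not tt-metal code) of the in1 mcast-extent change.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
    const uint32_t grid_x = 8, grid_y = 8;  // compute_with_storage_grid_size
    const uint32_t num_cores = 5;           // cores with work, row-major from (0, 0)

    // Pre-patch: in1_mcast_num_cores spanned the whole grid minus the sender.
    const uint32_t old_mcast_num_cores = grid_x * grid_y - 1;  // 63

    // Post-patch: the 5 working cores occupy (0,0)..(4,0), so the bounding
    // box holds 5 cores and only 4 of them receive the mcast.
    const uint32_t bbox_w = std::min(num_cores, grid_x);
    const uint32_t bbox_h = (num_cores + grid_x - 1) / grid_x;  // ceil-div
    const uint32_t new_mcast_num_cores = bbox_w * bbox_h - 1;   // 4

    std::cout << "in1_mcast_num_cores: " << old_mcast_num_cores
              << " -> " << new_mcast_num_cores << '\n';
    return 0;
}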