From 8cee5d432292215751a71b298603dd958deb4868 Mon Sep 17 00:00:00 2001 From: yugaoTT Date: Tue, 10 Dec 2024 21:44:29 +0000 Subject: [PATCH] #0: fix in0 height sharding case --- .../ttnn/unit_tests/operations/test_matmul.py | 42 ++++++++++++------- ...der_bmm_tile_layout_in0_sender_padding.cpp | 26 ++++++++---- ...ti_core_reuse_mcast_1d_program_factory.cpp | 10 +++-- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index 4f7cd273c7fc..c411ab466311 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -996,8 +996,8 @@ def run_matmul_1d_multiple_output_blocks_per_core( per_core_M = m // 32 per_core_N = n // num_cores // 32 + uneven_width else: - in0_block_w = k // 32 - per_core_M = m // 32 // num_cores + uneven_width + in0_block_w = k // 32 // 2 # test exracting shards + per_core_M = m // 32 // num_cores per_core_N = n // 32 out_block_h = per_core_M // num_out_block_h out_block_w = per_core_N // num_out_block_w @@ -1017,13 +1017,21 @@ def run_matmul_1d_multiple_output_blocks_per_core( in0 = torch.randn(in0_shape).bfloat16().float() in1 = torch.randn(in1_shape).bfloat16().float() - if in_sharded and mcast_in0: - in0_memory_config = ttnn.create_sharded_memory_config( - (1, 1, m, k), - core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - ) + if in_sharded: + if mcast_in0: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) else: in0_memory_config = ttnn.DRAM_MEMORY_CONFIG in1_memory_config = ttnn.DRAM_MEMORY_CONFIG @@ -1080,11 +1088,17 @@ def run_matmul_1d_multiple_output_blocks_per_core( fp32_dest_acc_en=False, packer_l1_acc=True, ) - if out_sharded and mcast_in0: - out_mem_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, - buffer_type=ttnn.BufferType.L1, - ) + if out_sharded: + if mcast_in0: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) else: out_mem_config = ttnn.DRAM_MEMORY_CONFIG diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp index c84a4d4825a0..0532ae005cc4 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp @@ -71,14 +71,17 @@ void kernel_main() { // In case we need to send multiple blocks per shard, in0 sharded cb is cb2 and we extract the sub-blocks to cb0 constexpr uint32_t shard_read_stride = shard_width_in_tiles * in0_single_tile_size_bytes; constexpr uint32_t shard_read_width = in0_single_tile_size_bytes * in0_block_w; + constexpr uint32_t shard_num_tiles = shard_width_in_tiles * shard_height_in_tiles; + constexpr uint32_t in0_tensor_next_h_dim_block_stride_bytes = + in0_tensor_next_h_dim_block_stride * in0_single_tile_size_bytes; - uint64_t noc_shard_read_start_addr = 0; + uint32_t noc_shard_read_start_addr = 0; if constexpr (extract_shard_sub_blocks) { constexpr uint32_t cb_id_in2 = 2; // in0 sharded cb if extract_shard_sub_blocks - noc_shard_read_start_addr = get_noc_addr(get_read_ptr(cb_id_in2)); + noc_shard_read_start_addr = get_read_ptr(cb_id_in2); } else { - cb_reserve_back(cb_id_in0, in0_block_num_tiles); - cb_push_back(cb_id_in0, in0_block_num_tiles); + cb_reserve_back(cb_id_in0, shard_num_tiles); + cb_push_back(cb_id_in0, shard_num_tiles); } #else constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0); @@ -113,9 +116,15 @@ void kernel_main() { #endif for (uint32_t b = 0; b < batch; ++b) { +#ifdef IN0_SHARDED + uint32_t in0_tensor_current_h_dim_block_start_addr = noc_shard_read_start_addr; +#endif uint32_t in0_tensor_current_h_dim_block_tile_id = in0_tensor_start_tile_id; for (uint32_t bh = 0; bh < num_blocks_h_dim; ++bh) { for (uint32_t bw = 0; bw < num_blocks_w_dim; ++bw) { +#ifdef IN0_SHARDED + uint32_t in0_tensor_current_inner_dim_block_start_addr = in0_tensor_current_h_dim_block_start_addr; +#endif uint32_t in0_tensor_current_inner_dim_block_start_tile_id = in0_tensor_current_h_dim_block_tile_id; for (uint32_t block = 0; block < num_blocks_inner_dim; ++block) { if constexpr (fuse_op) { @@ -159,16 +168,16 @@ void kernel_main() { in0_start_address = l1_write_addr_in0; // copy start address of block, to be used for mcasting #endif - uint64_t noc_shard_read_addr = noc_shard_read_start_addr; - noc_shard_read_start_addr += shard_read_width; + uint64_t noc_shard_read_addr = get_noc_addr(in0_tensor_current_inner_dim_block_start_addr); - for (uint32_t i = 0; i < shard_height_in_tiles; i++) { + for (uint32_t i = 0; i < in0_block_h; i++) { noc_async_read(noc_shard_read_addr, l1_write_addr_in0, shard_read_width); l1_write_addr_in0 += shard_read_width; noc_shard_read_addr += shard_read_stride; } + in0_tensor_current_inner_dim_block_start_addr += shard_read_width; noc_async_read_barrier(); } #endif @@ -216,6 +225,9 @@ void kernel_main() { #endif } } +#ifdef IN0_SHARDED + in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride_bytes; +#endif in0_tensor_current_h_dim_block_tile_id += in0_tensor_next_h_dim_block_stride; } in0_tensor_start_tile_id += MtKt; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp index 9b9c2d476916..4e299855332e 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp @@ -1024,7 +1024,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t in0_block_tiles = in0_block_h * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; if (in0_is_sharded) { - in0_CB_tiles = num_blocks * in0_CB_tiles * B; + in0_CB_tiles = num_blocks * per_core_M * in0_block_w * B; } else if (B * num_blocks > 1) { in0_CB_tiles = in0_CB_tiles * 2; // double buffer } @@ -1045,6 +1045,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( extract_shard_sub_blocks = true; } } + uint32_t in2_CB_tiles = in0_block_tiles; + uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size; uint32_t in1_block_tiles = out_block_w * in0_block_w; uint32_t in1_CB_tiles = in1_block_tiles; @@ -1384,7 +1386,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1( CBHandle cb_src2 = 0; if (in0_is_sharded and extract_shard_sub_blocks) { // in0_is_sharded is technically redundant tt_metal::CircularBufferConfig src2_cb_config = - tt_metal::CircularBufferConfig(in0_CB_size, {{src2_cb_index, in0_data_format}}) + tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) .set_globally_allocated_address(*in0_buffer) .set_tile_dims(src2_cb_index, in0_tile); @@ -1394,8 +1396,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( "CB {} :: PS = {}, NP = {}, TOTAL = {}", src2_cb_index, in0_single_tile_size, - in0_CB_size / in0_single_tile_size, - in0_CB_size); + in2_CB_size / in0_single_tile_size, + in2_CB_size); } uint32_t src1_cb_index = 1;