diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
index 61e674d98c4..6eb3991d34d 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py
@@ -11,7 +11,14 @@
     comp_equal,
     comp_pcc,
 )
-from models.utility_functions import is_wormhole_b0, is_wormhole_b0, is_blackhole, skip_for_blackhole
+from models.utility_functions import (
+    is_wormhole_b0,
+    is_wormhole_b0,
+    is_blackhole,
+    skip_for_blackhole,
+    skip_for_grayskull,
+    run_for_wormhole_b0,
+)
 from loguru import logger

 from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero, roundup32
@@ -682,8 +689,7 @@ def test_bcast_hw(device, num_cores, in0_height_sharded, out_height_sharded, in_
     out_mem_config = ttnn.DRAM_MEMORY_CONFIG

     if in0_height_sharded:
-        compute_with_storage_grid_size = device.compute_with_storage_grid_size()
-        device_grid_size = ttnn.CoreGrid(y=compute_with_storage_grid_size.y, x=compute_with_storage_grid_size.x)
+        device_grid_size = ttnn.CoreGrid(y=8, x=8) if num_cores == 64 else ttnn.CoreGrid(y=1, x=1)
         tt_in0_height_sharded = ttnn.to_memory_config(
             tt_in0_dram,
@@ -2418,3 +2424,137 @@ def test_interleaved_2_sharded_DRAM(device, dtype, y):
     )

     yt = ttnn.interleaved_to_sharded(xt, shard_grid, (y // 8, 18 * 32), shard_scheme, ttnn.ShardOrientation.ROW_MAJOR)
+
+
+@run_for_wormhole_b0()
+@pytest.mark.parametrize(
+    "seq_len",
+    (32,),
+)
+def test_llama_mlp_width_sharded_to_interleaved_pcc_err(device, seq_len, use_program_cache):
+    dim_in = 4096
+    dim_hidden = int(3.5 * dim_in / 4)  # 3584
+    dim_out = dim_in
+    # Create random input tensor
+    input_tensor = torch.randn(1, 1, int(seq_len), dim_in)
+    # Create random weight matrices
+    w1 = torch.randn(dim_hidden, dim_in)
+    w2 = torch.randn(dim_out, dim_hidden)
+    # Pytorch reference implementation
+    ## First linear layer
+    hidden = torch.matmul(input_tensor, w1.t())
+    ## Second linear layer
+    output_w2 = torch.matmul(hidden, w2.t())
+    ## Add residual connection
+    reference_output = output_w2 + input_tensor
+    # TTNN implementation
+    input_mem_config = ttnn.create_sharded_memory_config(
+        (
+            32,
+            128,
+        ),  # Shard shape: [32, 128] -> 1 shard per core
+        ttnn.CoreGrid(x=8, y=4),
+        ttnn.ShardStrategy.WIDTH,
+        ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+    w1_out_reshard_mem_config = ttnn.create_sharded_memory_config(
+        (
+            32,
+            128,
+        ),  # Shard shape: [32, 128] -> 1 shard per core
+        ttnn.CoreGrid(x=7, y=4),
+        ttnn.ShardStrategy.WIDTH,
+        ttnn.ShardOrientation.ROW_MAJOR,
+        use_height_and_width_as_shard_shape=True,
+    )
+    dram_core_range_set = ttnn.CoreRangeSet(
+        {
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, 0),
+                ttnn.CoreCoord(11, 0),
+            ),
+        }
+    )
+    w1_w3_mem_config = ttnn.MemoryConfig(
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+        ttnn.BufferType.DRAM,
+        ttnn.ShardSpec(dram_core_range_set, (4096, 320), ttnn.ShardOrientation.ROW_MAJOR, False),
+    )
+    w2_mem_config = ttnn.MemoryConfig(
+        ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+        ttnn.BufferType.DRAM,
+        ttnn.ShardSpec(dram_core_range_set, (3584, 352), ttnn.ShardOrientation.ROW_MAJOR, False),
+    )
+    pc_1 = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
+        in0_block_w=4,
+        per_core_M=1,
+        per_core_N=4,
+        fused_activation=None,
+    )
+    pc_2 = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
+        in0_block_w=4,
+        per_core_M=1,
+        per_core_N=5,
+        fused_activation=None,
+    )
+    ## convert input tensor and weights to TTNN tensors
+    tt_input = ttnn.from_torch(
+        input_tensor,
+        device=device,
+        dtype=ttnn.bfloat8_b,
+        memory_config=input_mem_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+    as_sharded_tensor = lambda w, type, dim, mem_config: ttnn.as_tensor(
+        w,  # Grab only the wX part of the name
+        dtype=type,
+        device=device,
+        layout=ttnn.TILE_LAYOUT,
+        memory_config=mem_config,
+    )
+    # Sharded weights
+    tt_w1 = as_sharded_tensor(w1.t(), ttnn.bfloat8_b, dim=-1, mem_config=w1_w3_mem_config)
+    tt_w2 = as_sharded_tensor(w2.t(), ttnn.bfloat8_b, dim=-2, mem_config=w2_mem_config)
+    ## MLP takes replicated inputs and produces fractured outputs
+    logger.info(f"tt_input shape: {tt_input.shape}")
+    logger.info(f"tt_input memory config: {tt_input.memory_config()}")
+    w1_out = ttnn.linear(
+        tt_input,
+        tt_w1,
+        core_grid=None,
+        dtype=ttnn.bfloat16,
+        program_config=pc_1,
+        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+    )
+    logger.info(f"w1_out shape: {w1_out.shape}")
+    logger.info(f"w1_out memory config: {w1_out.memory_config()}")
+    w1_out = ttnn.reshard(w1_out, w1_out_reshard_mem_config)
+    w2_out = ttnn.linear(
+        w1_out,
+        tt_w2,
+        core_grid=None,
+        dtype=ttnn.bfloat16,
+        program_config=pc_2,
+        memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+    )
+    logger.info(f"w2_out shape: {w2_out.shape}")
+    logger.info(f"w2_out memory config: {w2_out.memory_config()}")
+    w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.L1_MEMORY_CONFIG)
+    tt_input = ttnn.sharded_to_interleaved(tt_input, ttnn.L1_MEMORY_CONFIG)
+
+    # ## Add residual connection
+    tt_input_torch = ttnn.to_torch(tt_input)
+    tt_w2_out_torch = ttnn.to_torch(w2_out)
+    tt_output = ttnn.add(tt_input, w2_out)
+    tt_output_torch = ttnn.to_torch(tt_output)
+    pcc_required = 0.99
+    passing_w2_out, pcc_message_w2_out = comp_pcc(output_w2, tt_w2_out_torch)
+    passing_input, pcc_message_input = comp_pcc(input_tensor, tt_input_torch)
+    passing, pcc_message = comp_pcc(reference_output, tt_output_torch, pcc_required)
+    logger.info(f"w2_out PCC: {pcc_message_w2_out}")
+    logger.info(f"input PCC: {pcc_message_input}")
+    logger.info(f"residual PCC: {pcc_message}")
+    assert passing_w2_out
+    assert passing_input
+    assert passing
diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
index 2cb58883bf1..b6fc9f0e5c8 100644
--- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp
@@ -18,7 +18,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
     tt_metal::Program program{};

     uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width,
-        num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_per_shard_height_last,
+        num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_height, num_units_per_shard_height_last,
         num_units_per_shard_width_last;

     tt_metal::Device* device = input.device();
@@ -30,7 +30,12 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
     auto shard_strategy = input.memory_config().memory_layout;

     bool rm_orientation = shard_spec.orientation == ShardOrientation::ROW_MAJOR;
-    CoreCoord end_core = (*shard_spec.grid.ranges().rbegin()).end_coord;
+    auto& all_cores = shard_spec.grid;
+    uint32_t num_cores = all_cores.num_cores();
+    uint32_t num_cores_unpadded = num_cores;
+    const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);
+
+    CoreCoord end_core = cores[num_cores - 1];
     if (output.get_layout() == Layout::TILE) {
         num_units = input.volume() / TILE_HW;
         input_unit_size = tt_metal::detail::TileSize(input_cb_data_format);
@@ -40,7 +45,7 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
         num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
         num_units_per_row = output.get_legacy_shape()[-1] / TILE_WIDTH;
         num_units_offset = num_units_per_row;
-        uint32_t num_units_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT / num_slices;
+        num_units_height = output.volume() / output.get_legacy_shape()[-1] / TILE_HEIGHT / num_slices;
         num_units_per_shard_height_last =
             num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
         num_units_per_shard_width_last =
@@ -55,17 +60,26 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
         num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width;
         num_units_per_row = output.get_legacy_shape()[-1] * output.element_size();
         num_units_offset = 1;
-        uint32_t num_units_height = input.volume() / input.get_legacy_shape()[-1];
+        num_units_height = input.volume() / input.get_legacy_shape()[-1];
         num_units_per_shard_height_last =
             num_units_per_shard_height - (round_up(num_units_height, num_units_per_shard_height) - num_units_height);
         num_units_per_shard_width_last =
             output_unit_size - (round_up(num_units_per_row, output_unit_size) - num_units_per_row);
     }

-    bool convert_df = input_cb_data_format != output_cb_data_format;
+    // re-calculate end_core in the case shard grid is larger than used grid
+    if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
+        num_cores_unpadded = div_up(num_units_height, num_units_per_shard_height);
+    } else if (shard_strategy == TensorMemoryLayout::WIDTH_SHARDED) {
+        if (output.get_layout() == Layout::TILE) {
+            num_cores_unpadded = div_up(num_units_per_row, num_units_per_shard_width);
+        } else {
+            num_cores_unpadded = div_up(num_units_per_row, output_unit_size);
+        }
+    }
+    TT_ASSERT(num_cores_unpadded == num_cores, "number of cores {} in shard spec not equal to the unpadded number of cores {}", num_cores_unpadded, num_cores);

-    auto& all_cores = shard_spec.grid;
-    uint32_t num_cores = all_cores.num_cores();
+    bool convert_df = input_cb_data_format != output_cb_data_format;

     uint32_t src0_cb_index = CB::c_in0;
     uint32_t out_cb_index = src0_cb_index;
@@ -141,13 +155,12 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
     uint32_t curr_idx_h = 0;
     uint32_t curr_idx_w = 0;

-    const auto cores = corerange_to_cores(all_cores, std::nullopt, rm_orientation);
     uint32_t padded_offset_bytes;

     for (const auto& core : cores) {
+        uint32_t shard_height = num_units_per_shard_height;
+        uint32_t shard_width = input.get_layout() == Layout::TILE ? num_units_per_shard_width : output_unit_size;
         if (input.get_layout() == Layout::TILE) {
-            uint32_t shard_height = num_units_per_shard_height;
-            uint32_t shard_width = num_units_per_shard_width;
             if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
                 if (core.x == end_core.x && core.y == end_core.y) {
                     shard_height = num_units_per_shard_height_last;
@@ -192,8 +205,6 @@ operation::ProgramWithCallbacks sharded_to_interleaved_multi_core(
                 curr_idx_h += num_units_per_row * num_units_per_shard_height;
             }
         } else {
-            uint32_t shard_height = num_units_per_shard_height;
-            uint32_t shard_width = output_unit_size;
             if (shard_strategy == TensorMemoryLayout::HEIGHT_SHARDED) {
                 if (core.x == end_core.x && core.y == end_core.y) {
                     shard_height = num_units_per_shard_height_last;
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index 622ebad74c0..7fdce36eebc 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1182,6 +1182,7 @@ void Matmul::validate(

                     // No padding
                     TT_FATAL(M == per_core_M, "Error");
+                    TT_FATAL(M == 1, "currently only support in0 tensor height of tile height");
                     TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error");
                     TT_FATAL(K % program_config.in0_block_w == 0, "Error");
                     TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error");
@@ -1406,7 +1407,7 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
                 } else if constexpr (std::is_same_v<
                                          ProgramConfigType,
                                          MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
-                    uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1];
+                    uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0];
                     uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];

                     auto input_tensor_b_shape = input_tensor_b.get_legacy_shape();
@@ -1415,7 +1416,14 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
                     TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width");

-                    CoreRangeSet all_cores = input_tensor_a.shard_spec().value().grid;
+                    uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
+                    uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
+                    uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
+                    uint32_t num_cores = num_blocks_x * num_blocks_y;
+                    auto end_core = input_tensor_a.shard_spec()->grid.bounding_box().end_coord;
+                    auto grid_size = CoreCoord{end_core.x + 1, end_core.y + 1};
+                    CoreRangeSet all_cores =
+                        num_cores_to_corerangeset(num_cores, grid_size, true);
                     ShardSpec shard_spec = ShardSpec{
                         all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, ShardOrientation::ROW_MAJOR};
                     auto mem_config = this->output_mem_config;
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp
index e24185ed8b9..72203d67390 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp
@@ -484,23 +484,23 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
         log_debug("all_worker_cores_ordered: {}", core);
     }

-    uint32_t per_core_N = (N + num_dram_banks - 1) / num_dram_banks;
-    uint32_t per_core_N_unpad = per_core_N;
-    auto subblock_hw = bmm_op_utils::get_matmul_subblock_params(per_core_M, per_core_N, false, false, fp32_dest_acc_en);
+    uint32_t per_core_N_compute = (N + num_dram_banks - 1) / num_dram_banks;
+    uint32_t per_core_N_in1_sender = per_core_N_compute;
+    auto subblock_hw = bmm_op_utils::get_matmul_subblock_params(per_core_M, per_core_N_compute, false, false, fp32_dest_acc_en);
     auto out_subblock_h = std::get<0>(subblock_hw);
     auto out_subblock_w = std::get<1>(subblock_hw);

     uint32_t max_subblock_w = fp32_dest_acc_en ? 4 : 8;
-    // it is bad for compute, pad per_core_N
+    // it is bad for compute, pad per_core_N_compute
     if (out_subblock_h == 1 and out_subblock_w < max_subblock_w) {
-        uint32_t num_subblock_w_per_core_N = per_core_N / out_subblock_w;
+        uint32_t num_subblock_w_per_core_N = per_core_N_compute / out_subblock_w;
         uint32_t num_iter = max_subblock_w - out_subblock_w;
         uint32_t new_out_subblock_w = out_subblock_w;
         uint32_t preferred_out_subblock_w = out_subblock_w;

         for (uint32_t i = 0; i < num_iter; ++i) {
             new_out_subblock_w += 1;
-            uint32_t new_num_subblock_w_per_core_N = (per_core_N + new_out_subblock_w - 1) / new_out_subblock_w;
+            uint32_t new_num_subblock_w_per_core_N = (per_core_N_compute + new_out_subblock_w - 1) / new_out_subblock_w;

             if (new_num_subblock_w_per_core_N < num_subblock_w_per_core_N) {
                 num_subblock_w_per_core_N = new_num_subblock_w_per_core_N;
@@ -508,10 +508,10 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
             }
         }
         out_subblock_w = preferred_out_subblock_w;
-        per_core_N = out_subblock_w * num_subblock_w_per_core_N;
+        per_core_N_compute = out_subblock_w * num_subblock_w_per_core_N;
     }

-    log_debug("per_core_M: {}, per_core_N: {}", per_core_M, per_core_N);
+    log_debug("per_core_M: {}, per_core_N_compute: {}, per_core_N_in1_sender: {}", per_core_M, per_core_N_compute, per_core_N_in1_sender);
     log_debug("out_subblock_h: {}, out_subblock_w: {}", out_subblock_h, out_subblock_w);

     uint32_t num_blocks = K / in0_block_w;
@@ -537,14 +537,14 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
         in0_CB_tiles = in0_CB_tiles * 2;  // double buffer
     }
     uint32_t in0_CB_size = in0_CB_tiles * in0_single_tile_size;
-    uint32_t in1_block_tiles = per_core_N_unpad * in0_block_w;
+    uint32_t in1_block_tiles = per_core_N_in1_sender * in0_block_w;
     uint32_t in1_CB_tiles = in1_block_tiles;
     if (B * num_blocks > 1) {
         in1_CB_tiles = in1_CB_tiles * 3;  // tripple buffer
     }
     uint32_t in1_CB_size = in1_CB_tiles * in1_single_tile_size;

-    uint32_t out_block_tiles = per_core_M * per_core_N;
+    uint32_t out_block_tiles = per_core_M * per_core_N_compute;
     uint32_t out_CB_tiles = out_block_tiles;  // No double buffer
     uint32_t out_CB_size = out_CB_tiles * output_single_tile_size;
     uint32_t interm0_CB_size = out_CB_tiles * interm0_single_tile_size;
@@ -559,7 +559,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
     uint32_t in2_CB_tiles = in2_block_tiles;
     uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size;

-    uint32_t in3_block_tiles = per_core_N_unpad;
+    uint32_t in3_block_tiles = per_core_N_in1_sender;
     uint32_t in3_CB_tiles = in3_block_tiles;  // No double buffer
     uint32_t in3_CB_size = in3_CB_tiles * bias_single_tile_size;

@@ -691,12 +691,12 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
         (std::uint32_t)in1_buffer_page_size,
         (std::uint32_t)in1_buffer_num_pages,
         // in1 block args
-        (std::uint32_t)per_core_N_unpad,                 // in1_block_w
-        (std::uint32_t)per_core_N_unpad * in0_block_w,   // in1_block_num_tiles
+        (std::uint32_t)per_core_N_in1_sender,                // in1_block_w
+        (std::uint32_t)per_core_N_in1_sender * in0_block_w,  // in1_block_num_tiles
         // in0/in1 common args
         (std::uint32_t)num_blocks,       // num_blocks
         (std::uint32_t)out_block_tiles,  // out_block_num_tiles
-        (std::uint32_t)per_core_N * output_single_tile_size,          // out_tensor_stride_w_bytes
+        (std::uint32_t)per_core_N_compute * output_single_tile_size,  // out_tensor_stride_w_bytes
         (std::uint32_t)per_core_N_storage * output_single_tile_size,  // out_reshard_tensor_stride_w_bytes
         (std::uint32_t)per_core_M};
     if (bias_buffer != nullptr) {
@@ -767,8 +767,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded(

     // Compute kernel compile time args
     uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w;
-    uint32_t in1_num_subblocks = (per_core_N / out_subblock_w);
-    uint32_t in1_per_core_w = per_core_N_unpad;
+    uint32_t in1_num_subblocks = (per_core_N_compute / out_subblock_w);
+    uint32_t in1_per_core_w = per_core_N_in1_sender;
     uint32_t out_subblock_num_tiles = out_subblock_h * out_subblock_w;

     std::vector<uint32_t> compute_kernel_args = {
@@ -932,22 +932,6 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
             in3_CB_size);
     }

-    // Parameters for last row, col, or block
-    uint32_t last_block_h = M % per_core_M == 0 ? per_core_M : M % per_core_M;
-    uint32_t last_block_w = N % per_core_N == 0 ? per_core_N : N % per_core_N;
-    uint32_t last_block_num_nonzero_subblocks_h = (last_block_h - 1) / out_subblock_h + 1;
-    uint32_t last_block_num_nonzero_subblocks_w = (last_block_w - 1) / out_subblock_w + 1;
-    uint32_t last_subblock_of_last_block_h =
-        last_block_h % out_subblock_h == 0 ? out_subblock_h : last_block_h % out_subblock_h;
-    uint32_t last_subblock_of_last_block_w =
-        last_block_w % out_subblock_w == 0 ? out_subblock_w : last_block_w % out_subblock_w;
-    uint32_t last_block_padded_subblock_tiles_addr_skip =
-        output_single_tile_size * (out_subblock_w - last_subblock_of_last_block_w);
-    uint32_t last_block_padded_block_tiles_w_skip =
-        (out_subblock_w * out_subblock_h) * (per_core_N / out_subblock_w - last_block_num_nonzero_subblocks_w);
-    uint32_t last_block_padded_block_tiles_h_skip =
-        (per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h);
-
     std::vector<tt::tt_metal::KernelHandle> reader_kernel_ids;
     std::vector<tt::tt_metal::KernelHandle> writer_kernel_ids;

@@ -1063,6 +1047,11 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
         }
     }

+    uint32_t num_cores_written_back = (N + per_core_N_storage - 1) / per_core_N_storage;
+    uint32_t expected_max_total_width = num_cores_written_back * per_core_N_storage;
+    tt::log_debug("per_core_N_storage: {}", per_core_N_storage);
+    tt::log_debug("num_cores_written_back: {}", num_cores_written_back);
+    uint32_t total_tensor_width_written_back = 0;
     for (uint32_t i = 0; i < all_worker_cores_ordered.size(); ++i) {
         auto core = all_worker_cores_ordered[i];

@@ -1092,19 +1081,26 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
         bank_id = (bank_id + 1) % num_dram_banks;

-        if (per_core_N_unpad < per_core_N_storage) {
-            if (curr_storage_core_idx < all_storage_cores_vec.size()) {
+        if (per_core_N_in1_sender < per_core_N_storage) {
+            if (curr_storage_core_idx < num_cores_written_back) {
                 uint32_t remaining_per_core_N_storage = (per_core_N_storage - per_core_N_storage_curr_stride);
                 uint32_t per_core_N_reshard_1 =
-                    (remaining_per_core_N_storage > per_core_N_unpad) ? per_core_N_unpad : remaining_per_core_N_storage;
-                uint32_t per_core_N_reshard_2 = per_core_N_unpad - per_core_N_reshard_1;
+                    (remaining_per_core_N_storage > per_core_N_in1_sender) ? per_core_N_in1_sender : remaining_per_core_N_storage;
+                uint32_t per_core_N_reshard_2 = per_core_N_in1_sender - per_core_N_reshard_1;

-                if (per_core_N_reshard_2 != 0 and (curr_storage_core_idx + 1) < all_storage_cores_vec.size()) {
+                if (per_core_N_reshard_2 != 0 and (curr_storage_core_idx + 1) < num_cores_written_back) {
                     mm_in1_sender_writer_args.push_back(2);
                 } else {
                     mm_in1_sender_writer_args.push_back(1);
                 }

+                log_debug(
+                    "curr worker core: {}, send back: {} tiles to storage core: {}, coord: {}",
+                    i,
+                    per_core_N_reshard_1,
+                    curr_storage_core_idx,
+                    mcast_senders_coords[curr_storage_core_idx]);
+
                 mm_in1_sender_writer_args.push_back(
                     per_core_N_storage_curr_stride * output_single_tile_size);  // reshard_tensor_start_offset
                 mm_in1_sender_writer_args.push_back(
@@ -1114,33 +1110,45 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
                 mm_in1_sender_writer_args.push_back(
                     in0_mcast_sender_noc_y[curr_storage_core_idx]);  // in0_mcast_sender_noc_y

-                if (per_core_N_reshard_2 != 0 and (curr_storage_core_idx + 1) < all_storage_cores_vec.size()) {
+                total_tensor_width_written_back += per_core_N_reshard_1;
+
+                if (per_core_N_reshard_2 != 0 and (curr_storage_core_idx + 1) < num_cores_written_back) {
+                    log_debug(
+                        "curr worker core: {}, send back: {} tiles to storage core: {}, coord: {}",
+                        i,
+                        per_core_N_reshard_2,
+                        curr_storage_core_idx + 1,
+                        mcast_senders_coords[curr_storage_core_idx + 1]);
+
                     mm_in1_sender_writer_args.push_back(
                         per_core_N_reshard_2 * output_single_tile_size);  // per_core_N_reshard_bytes_2
                     mm_in1_sender_writer_args.push_back(
                         in0_mcast_sender_noc_x[curr_storage_core_idx + 1]);  // in0_mcast_sender_noc_x
                     mm_in1_sender_writer_args.push_back(
                         in0_mcast_sender_noc_y[curr_storage_core_idx + 1]);  // in0_mcast_sender_noc_y
+
+                    total_tensor_width_written_back += per_core_N_reshard_2;
                 }
-                curr_storage_core_idx += (per_core_N_storage_curr_stride + per_core_N_unpad) / per_core_N_storage;
-                per_core_N_storage_curr_stride =
-                    (per_core_N_storage_curr_stride + per_core_N_unpad) % per_core_N_storage;
+                curr_storage_core_idx += (per_core_N_storage_curr_stride + per_core_N_in1_sender) / per_core_N_storage;
+                per_core_N_storage_curr_stride =
+                    (per_core_N_storage_curr_stride + per_core_N_in1_sender) % per_core_N_storage;
             }
         } else {
-            uint32_t num_iter = 0;
+            uint32_t num_cores_write_back = 0;

-            if (curr_storage_core < all_storage_cores_vec.size()) {
-                num_iter++;
+            if (curr_storage_core < num_cores_written_back) {
+                num_cores_write_back++;
+
+                worker_core_stride = per_core_N_storage - storage_core_stride;

                 log_debug(
-                    "curr worker core: {}, send back to storage core: {}, coord: {}",
+                    "curr worker core: {}, send back: {} tiles to storage core: {}, coord: {}",
                     curr_worker_core,
+                    worker_core_stride,
                     curr_storage_core,
                     mcast_senders_coords[curr_storage_core]);

-                worker_core_stride = per_core_N_storage - storage_core_stride;
-
                 mm_in1_sender_writer_args.push_back(
                     storage_core_stride * output_single_tile_size);  // reshard_tensor_start_offset
                 mm_in1_sender_writer_args.push_back(
@@ -1153,45 +1161,54 @@ operation::ProgramWithCallbacks create_program_dram_sharded(
                 curr_storage_core += (storage_core_stride + worker_core_stride) / per_core_N_storage;
                 storage_core_stride = (storage_core_stride + worker_core_stride) % per_core_N_storage;

-                if (worker_core_stride >= per_core_N_unpad) {
+                if (worker_core_stride >= per_core_N_in1_sender) {
                     curr_worker_core += 1;
                 }

-                while (curr_worker_core <= i and curr_storage_core < all_storage_cores_vec.size()) {
-                    num_iter++;
+                total_tensor_width_written_back += worker_core_stride;
+
+                while (curr_worker_core <= i and curr_storage_core < num_cores_written_back) {
+                    num_cores_write_back++;
+
+                    bool increment_worker_core = (worker_core_stride + per_core_N_storage) >= per_core_N_in1_sender;
+                    uint32_t current_worker_stride_total = increment_worker_core ? per_core_N_in1_sender : worker_core_stride + per_core_N_storage;
+                    uint32_t current_worker_write_back_tiles = current_worker_stride_total - worker_core_stride;

                     log_debug(
-                        "curr worker core: {}, send back to storage core: {}, coord: {}",
+                        "curr worker core: {}, send back: {} tiles to storage core: {}, coord: {}",
                         curr_worker_core,
+                        current_worker_write_back_tiles,
                         curr_storage_core,
                         mcast_senders_coords[curr_storage_core]);

-                    uint32_t stride = worker_core_stride + per_core_N_storage;
-                    if (stride >= per_core_N_unpad) {
-                        stride = per_core_N_unpad;
+                    if (increment_worker_core) {
                         curr_worker_core += 1;
                     }

                     mm_in1_sender_writer_args.push_back(
-                        (stride - worker_core_stride) * output_single_tile_size);  // per_core_N_reshard
+                        current_worker_write_back_tiles * output_single_tile_size);  // per_core_N_reshard
                     mm_in1_sender_writer_args.push_back(
                         in0_mcast_sender_noc_x[curr_storage_core]);  // in0_mcast_sender_noc_x
                     mm_in1_sender_writer_args.push_back(
                         in0_mcast_sender_noc_y[curr_storage_core]);  // in0_mcast_sender_noc_y

-                    storage_core_stride = (stride - worker_core_stride) % per_core_N_storage;
-                    curr_storage_core += (stride - worker_core_stride) / per_core_N_storage;
-                    worker_core_stride = stride;
+                    total_tensor_width_written_back += current_worker_write_back_tiles;
+
+                    storage_core_stride = current_worker_write_back_tiles % per_core_N_storage;
+                    curr_storage_core += current_worker_write_back_tiles / per_core_N_storage;
+                    worker_core_stride = current_worker_stride_total;
                 }
             }
-            mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 5, num_iter);
+            mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 5, num_cores_write_back);
         }
         tt_metal::SetRuntimeArgs(program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args);
         writer_kernel_ids.push_back(mm_kernel_in1_sender_writer_id);
     }

+    TT_ASSERT(total_tensor_width_written_back <= expected_max_total_width, "more datums written back to sharded tensor, L1 corruption, expected: {}, actual: {}", expected_max_total_width, total_tensor_width_written_back);
+
     auto override_runtime_arguments_callback =
         [writer_kernel_ids, all_worker_cores_ordered, cb_src2, cb_output_reshard](
             const void* operation,
diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
index 7563bd2a417..1047d162ef2 100644
--- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/layernorm_op.cpp
@@ -175,7 +175,7 @@ std::vector<Tensor> LayerNorm::create_output_tensors(const std::vector<Tensor>& input_tensors

             auto shard_spec = input_tensor.shard_spec().value();
             shard_spec.shape[1] = output_shape[3];
-            CoreRange first_core_range(CoreCoord(0, 0), CoreCoord(1, 1));
+            CoreRange first_core_range(CoreCoord(0, 0), CoreCoord(0, 0));
             CoreRangeSet core_range_set({first_core_range});
             shard_spec.grid = core_range_set;
             auto mem_config = this->output_mem_config;