diff --git a/tests/tt_metal/tt_metal/test_bcast.cpp b/tests/tt_metal/tt_metal/test_bcast.cpp index 6382ca5a7ab..b60816779f6 100644 --- a/tests/tt_metal/tt_metal/test_bcast.cpp +++ b/tests/tt_metal/tt_metal/test_bcast.cpp @@ -163,7 +163,7 @@ int main(int argc, char **argv) { ref_bcast_values[j] = bfloat16(bcast_1value+(j%7)).to_uint16(); // convert the reference broadcast tensor to tiled format tiled_bcast_values = convert_layout( - ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); TT_FATAL(tiled_bcast_values[0] == bcast_1value16, "Error"); // restore ref values and shape to 1 ref_bcast_shape[3] = 1; @@ -183,7 +183,7 @@ int main(int argc, char **argv) { // add something not too large but different between tiles ref_bcast_values[j] = bfloat16(bcast_1value+(j%7)).to_uint16(); tiled_bcast_values = convert_layout( - ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); num_bcast_tiles = NC*Wt; // restore values and shape to W } else if (bcast_dim == BcastDim::W) { @@ -194,7 +194,7 @@ int main(int argc, char **argv) { // add something not too large but different between tiles ref_bcast_values[j] = bfloat16(bcast_1value+(j%7)).to_uint16(); tiled_bcast_values = convert_layout( - ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + ref_bcast_values, ref_bcast_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); num_bcast_tiles = NC*Ht; } @@ -292,7 +292,7 @@ int main(int argc, char **argv) { tt_metal::detail::LaunchProgram(device, program); - // The kernel will view the input as TILED32_4FACES + // The kernel will view the input as TILED_NFACES vector result_vec; tt_metal::detail::ReadFromBuffer(dst_dram_buffer, result_vec); @@ -313,7 +313,7 @@ int main(int argc, char **argv) { // recover a linear view of input vector for consumption by gold_ function auto u16_src0_vec = u16_from_u32_vector(src0_vec); vector src_linear = convert_layout( - u16_src0_vec, shape, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + u16_src0_vec, shape, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); vector gold_added = gold_bcast_op( src_linear, shape, ref_bcast_values, bcast_dim, bcast_op); // result is uint16_t untilized @@ -321,7 +321,7 @@ int main(int argc, char **argv) { vector shapeR{shape[0], shape[1], shape[2], shape[3]}; auto gold_4f_u32 = u32_from_u16_vector( convert_layout( - gold_added, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES)); + gold_added, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES)); pass &= packed_uint32_t_vector_comparison(result_vec, gold_4f_u32, comparison_function, &argfail); if (!pass) diff --git a/tests/tt_metal/tt_metal/test_bfp4_conversion.cpp b/tests/tt_metal/tt_metal/test_bfp4_conversion.cpp index e11b8646997..2ad2d91a581 100644 --- a/tests/tt_metal/tt_metal/test_bfp4_conversion.cpp +++ b/tests/tt_metal/tt_metal/test_bfp4_conversion.cpp @@ -32,7 +32,7 @@ int main(int argc, char **argv) { } std::vector shape_vec = {1, num_tiles, 32, 32}; - std::vector tiled_fp32_vec = convert_layout(fp32_vec, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + std::vector tiled_fp32_vec = convert_layout(fp32_vec, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); std::vector 
packed_bfp4b_tile_vec_rm_in = pack_fp32_vec_as_bfp4_tiles(fp32_vec, /*row_major_input=*/true, /*is_exp_a=*/false); std::vector unpacked_bfp4b_tile_vec_rm_out = unpack_bfp4_tiles_into_float_vec(packed_bfp4b_tile_vec_rm_in, /*row_major_output*/true, /*is_exp_a=*/false); @@ -44,8 +44,8 @@ int main(int argc, char **argv) { // //////////////////////////////////////////////////////////////////////////// // // Validation // //////////////////////////////////////////////////////////////////////////// - std::vector tiled_to_rm_fp32_vec = convert_layout(unpacked_bfp4b_tile_vec_tile_out, shape_vec, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); - std::vector rm_to_tiled_fp32_vec = convert_layout(unpacked_bfp4b_tile_vec_rm_out, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + std::vector tiled_to_rm_fp32_vec = convert_layout(unpacked_bfp4b_tile_vec_tile_out, shape_vec, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); + std::vector rm_to_tiled_fp32_vec = convert_layout(unpacked_bfp4b_tile_vec_rm_out, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); // Ensure that passing in row_major_input=true and row_major_output=true are inverses of row_major_input=false and row_major_output=false yield the same result pass &= (packed_bfp4b_tile_vec_rm_in == packed_bfp4b_tile_vec_tile_in); diff --git a/tests/tt_metal/tt_metal/test_bfp8_conversion.cpp b/tests/tt_metal/tt_metal/test_bfp8_conversion.cpp index 4e9d0fd8bca..08331ab8c9a 100644 --- a/tests/tt_metal/tt_metal/test_bfp8_conversion.cpp +++ b/tests/tt_metal/tt_metal/test_bfp8_conversion.cpp @@ -32,7 +32,7 @@ int main(int argc, char **argv) { } std::vector shape_vec = {1, 1, 32, 32}; - std::vector tiled_fp32_vec = convert_layout(fp32_vec, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + std::vector tiled_fp32_vec = convert_layout(fp32_vec, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); std::vector packed_bfp8b_tile_vec_rm_in = pack_fp32_vec_as_bfp8_tiles(fp32_vec, /*row_major_input=*/true, /*is_exp_a=*/false); std::vector unpacked_bfp8b_tile_vec_rm_out = unpack_bfp8_tiles_into_float_vec(packed_bfp8b_tile_vec_rm_in, /*row_major_output*/true, /*is_exp_a=*/false); @@ -44,8 +44,8 @@ int main(int argc, char **argv) { // //////////////////////////////////////////////////////////////////////////// // // Validation // //////////////////////////////////////////////////////////////////////////// - std::vector tiled_to_rm_fp32_vec = convert_layout(unpacked_bfp8b_tile_vec_tile_out, shape_vec, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); - std::vector rm_to_tiled_fp32_vec = convert_layout(unpacked_bfp8b_tile_vec_rm_out, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + std::vector tiled_to_rm_fp32_vec = convert_layout(unpacked_bfp8b_tile_vec_tile_out, shape_vec, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); + std::vector rm_to_tiled_fp32_vec = convert_layout(unpacked_bfp8b_tile_vec_rm_out, shape_vec, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); // Ensure that passing in row_major_input=true and row_major_output=true are inverses of row_major_input=false and row_major_output=false yield the same result pass &= (packed_bfp8b_tile_vec_rm_in == packed_bfp8b_tile_vec_tile_in); diff --git a/tests/tt_metal/tt_metal/test_bmm.cpp b/tests/tt_metal/tt_metal/test_bmm.cpp index 3ca50e825ae..21f021714cc 100644 --- a/tests/tt_metal/tt_metal/test_bmm.cpp +++ b/tests/tt_metal/tt_metal/test_bmm.cpp @@ -164,13 +164,13 @@ 
int main(int argc, char **argv) { vector shapeC = {1, B, Mt*32, Nt*32}; auto u16_src0_vec = u16_from_u32_vector(src0_vec); auto u16_src1_vec = u16_from_u32_vector(src1_vec); - vector src0_linear = convert_layout(u16_src0_vec, shapeA, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); - vector src1_linear = convert_layout(u16_src1_vec, shapeB, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + vector src0_linear = convert_layout(u16_src0_vec, shapeA, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); + vector src1_linear = convert_layout(u16_src1_vec, shapeB, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); vector ref_bmm = gold_bmm(shapeA, src0_linear, shapeB, src1_linear); // Tilize gold from row major and convert to pairs (uint32_t) auto gold_4f_u32 = u32_from_u16_vector( convert_layout( - ref_bmm, shapeC, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES)); + ref_bmm, shapeC, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES)); pass &= packed_uint32_t_vector_comparison(result_vec, gold_4f_u32, comparison_function, &argfail); if (!pass) diff --git a/tests/tt_metal/tt_metal/test_transpose_hc.cpp b/tests/tt_metal/tt_metal/test_transpose_hc.cpp index 8848a97a275..060ec8ac4ca 100644 --- a/tests/tt_metal/tt_metal/test_transpose_hc.cpp +++ b/tests/tt_metal/tt_metal/test_transpose_hc.cpp @@ -184,12 +184,12 @@ int main(int argc, char **argv) { }; // recover a linear view of input vector for consumption by gold_ function - vector src_linear = convert_layout(src_4f_16, shape, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + vector src_linear = convert_layout(src_4f_16, shape, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); vector gold_reduced = gold_transpose_hc(src_linear, shape); // result is uint16_t untilized // Tilize from row major and convert to pairs (uint32_t) vector shapeR{shape[0], shape[2], shape[1], shape[3]}; - auto gold_16_4f = convert_layout(gold_reduced, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + auto gold_16_4f = convert_layout(gold_reduced, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES); auto gold_4f_u32 = u32_from_u16_vector(gold_16_4f); auto u16_result = u16_from_u32_vector(result_vec); diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp index cc565be54fe..961d5fd111c 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp @@ -326,7 +326,7 @@ void run_single_core_reduce_program(tt_metal::Device* device, const ReduceConfig tt_metal::detail::LaunchProgram(device, program); - // The kernel will view the input as TILED32_4FACES + // The kernel will view the input as TILED_NFACES std::vector result_vec; tt_metal::detail::ReadFromBuffer(dst_dram_buffer, result_vec); @@ -353,11 +353,11 @@ void run_single_core_reduce_program(tt_metal::Device* device, const ReduceConfig } } // recover a linear view of input vector for consumption by gold_ function - std::vector src_linear = convert_layout(u16_src0_vec, test_config.shape, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + std::vector src_linear = convert_layout(u16_src0_vec, test_config.shape, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); std::vector gold_reduced = test_config.golden_function(src_linear, test_config.shape, scaler, uint8_t(test_config.reduce_type), true); // result is uint16_t untilized // Tilize from row major and convert 
to pairs (uint32_t) - auto gold_4f_u32 = u32_from_u16_vector(convert_layout(gold_reduced, test_config.result_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES)); + auto gold_4f_u32 = u32_from_u16_vector(convert_layout(gold_reduced, test_config.result_shape, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES)); bool pass = packed_uint32_t_vector_comparison(result_vec, gold_4f_u32, comparison_function, &argfail); if (!pass) diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_transpose.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_transpose.cpp index 6aef7e62f9e..c6b0ccc87dd 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_transpose.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_transpose.cpp @@ -52,13 +52,13 @@ void validate_transpose_wh(const std::vector &src_vec, const std::vect // recover a linear view of input vector for consumption by gold_ function auto u16_src0_vec = u16_from_u32_vector(src_vec); - vector src_linear = convert_layout(u16_src0_vec, shape, TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + vector src_linear = convert_layout(u16_src0_vec, shape, TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR); vector gold_reduced = gold_transpose_wh(src_linear, shape); // result is uint16_t untilized // Tilize from row major and convert to pairs (uint32_t) TT_FATAL(shape.size() == 4, "Error"); vector shapeR{shape[0], shape[1], shape[3], shape[2]}; - auto gold_4f_u32 = u32_from_u16_vector(convert_layout(gold_reduced, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES)); + auto gold_4f_u32 = u32_from_u16_vector(convert_layout(gold_reduced, shapeR, TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES)); bool pass = packed_uint32_t_vector_comparison(result_vec, gold_4f_u32, comparison_function, &argfail); if (not pass) { diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index a8dc46ea819..15cfcbb2d77 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -7,7 +7,528 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import skip_for_grayskull, is_wormhole_b0, is_grayskull, is_blackhole +from models.utility_functions import skip_for_grayskull, is_wormhole_b0, is_grayskull, is_blackhole, run_for_wormhole_b0 + + +def find_max_subblock(out_block_h, out_block_w): + max_product = 0 + best_h = 1 + best_w = 1 + + for h in range(1, out_block_h + 1): + if out_block_h % h == 0: # h is a divisor of out_block_h + for w in range(1, out_block_w + 1): + if out_block_w % w == 0 and h * w <= 8: # w is a divisor and product condition met + if h * w > max_product: + max_product = h * w + best_h = h + best_w = w + if out_block_w > best_w: + best_h = 1 + return best_h, best_w, max_product + + +@pytest.mark.parametrize("n", [1]) +@pytest.mark.parametrize("c", [2]) +@pytest.mark.parametrize("h", [71]) +@pytest.mark.parametrize("w", [35]) +@pytest.mark.parametrize("tile_h", [1, 2, 4, 8, 16, 32]) +@pytest.mark.parametrize("tile_w", [16, 32]) +def test_tiny_tiles(device, n, c, h, w, tile_h, tile_w): + torch.manual_seed(0) + torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) + input_tensor = ttnn.from_torch( + torch_input_tensor, + tile=(tile_h, tile_w), + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + output_tensor = ttnn.to_torch(input_tensor) + assert_with_pcc(torch_input_tensor, 
output_tensor, 1) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("b", [8]) +@pytest.mark.parametrize("h", [4]) +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("k", [256]) +@pytest.mark.parametrize("n", [256]) +@pytest.mark.parametrize("tile_h", [16, 32]) +@pytest.mark.parametrize("tile_w", [16, 32]) +@pytest.mark.parametrize("in0_sharded", [True, False]) +@pytest.mark.parametrize("in1_sharded", [True, False]) +@pytest.mark.parametrize("out_sharded", [True, False]) +def test_matmul_reuse_config_sharded_tiny_tile( + device, b, h, m, k, n, tile_h, tile_w, in0_sharded, in1_sharded, out_sharded +): + torch.manual_seed(0) + + grid_size = (b, h) + + in0 = torch.ones([b, h, m, k]).bfloat16().float() + in1 = torch.randn([b, h, k, n]).bfloat16().float() + + if in0_sharded: + in0_memory_config = ttnn.create_sharded_memory_config( + (b, h, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.L1_MEMORY_CONFIG + in0_t = ttnn.from_torch( + in0, + tile=(tile_h, 32), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in0_memory_config, + ) + + if in1_sharded: + in1_memory_config = ttnn.create_sharded_memory_config( + (b, h, k, n), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in1_memory_config = ttnn.L1_MEMORY_CONFIG + in1_t = ttnn.from_torch( + in1, + tile=(32, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in1_memory_config, + ) + + out_block_h = m // tile_h + out_block_w = n // tile_w + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + program_config = ttnn.MatmulMultiCoreReuseProgramConfig( + compute_with_storage_grid_size=grid_size, + in0_block_w=k // 32, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + ) + if out_sharded: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.L1_MEMORY_CONFIG + # override the tile width for later ops + if out_sharded and tile_h <= 16: + output_tile = ttnn.Tile([tile_h, 32]) + else: + output_tile = ttnn.Tile([tile_h, tile_w]) + output_t = ttnn.matmul( + in0_t, in1_t, program_config=program_config, memory_config=out_mem_config, output_tile=output_tile + ) + output_tensor = ttnn.to_torch(output_t) + pt_out = in0 @ in1 + + assert_with_pcc(pt_out, output_tensor, 0.999) + + +def pad_to_dram_banks(num, tile_w, lcm=32 * 12): + remainder = num % lcm + if remainder == 0: + return num + padding_needed = lcm - remainder + padded_number = num + padding_needed + return padded_number + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("k", [8192]) +@pytest.mark.parametrize("n", [1280]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("grid_size", [(8, 1)]) +@pytest.mark.parametrize("tile_h", [16, 32]) +@pytest.mark.parametrize("tile_w", [16, 32]) +def test_matmul_in1_dram_sharded_tiny_tile(device, k, n, has_bias, grid_size, tile_h, tile_w): + # PCC issue when height not equal to tile height + m = tile_h + if is_grayskull(): + n_padded = n + num_banks = 8 + else: + num_banks = 12 + n_padded = pad_to_dram_banks(n, tile_w, tile_w * num_banks) + + in0_shape = [1, 1, m, k] + in1_shape = [1, 1, k, n] + 
in1_shard_shape = [k, n_padded // num_banks] + bias_shape = [1, 1, n] + bias_shard_shape = [tile_h, n_padded // num_banks] + num_cores = grid_size[0] * grid_size[1] + + in0_block_w = k // num_cores // 32 + out_block_h = m // tile_h + out_block_w = n // num_cores // tile_w + + sharded_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + in0_t = ttnn.from_torch( + in0, + tile=(tile_h, 32), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in0_memory_config, + ) + in1_shard_grid = ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1) + in1_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), in1_shard_grid)}) + in1_shard_spec = ttnn.ShardSpec(in1_shard_grid, in1_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) + in1_memory_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, in1_shard_spec) + in1_t = ttnn.from_torch( + in1, + tile=(32, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in1_memory_config, + ) + + if has_bias: + bias = torch.randn(bias_shape).bfloat16().float() + bias_padded = bias.unsqueeze(2) + bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, tile_h - bias_padded.size(2)), "constant", 0) + bias_shard_grid = ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1) + bias_shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), bias_shard_grid)}) + bias_shard_spec = ttnn.ShardSpec(bias_shard_grid, bias_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False) + bias_mem_config = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, bias_shard_spec + ) + bias_t = ttnn.from_torch( + bias_padded, + tile=(tile_h, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=bias_mem_config, + ) + + program_config = ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( + in0_block_w=in0_block_w // 4, + per_core_M=out_block_h, + per_core_N=out_block_w, + fused_activation=None, + ) + + if is_grayskull(): + compute_kernel_config = ttnn.GrayskullComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + if has_bias: + output_t = ttnn.linear( + in0_t, + in1_t, + bias=bias_t, + program_config=program_config, + memory_config=sharded_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=ttnn.Tile([tile_h, 32]) if tile_h <= 16 else ttnn.Tile([tile_h, tile_w]), + ) + else: + output_t = ttnn.matmul( + in0_t, + in1_t, + program_config=program_config, + memory_config=sharded_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=ttnn.Tile([tile_h, 32]) if tile_h <= 16 else ttnn.Tile([tile_h, tile_w]), + ) + output_tensor = ttnn.to_torch(output_t) + pt_out = in0 @ in1 + if has_bias: + pt_out += bias + + assert_with_pcc(pt_out, output_tensor, 0.999) + + 
+@run_for_wormhole_b0() +@pytest.mark.parametrize("m", [1536]) +@pytest.mark.parametrize("k", [1024]) +@pytest.mark.parametrize("n", [3072]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("grid_size", [(8, 4)]) +@pytest.mark.parametrize("tile_h", [16, 32]) +@pytest.mark.parametrize("tile_w", [16, 32]) +@pytest.mark.parametrize("in0_sharded", [True, False]) +@pytest.mark.parametrize("out_sharded", [True, False]) +def test_matmul_2d_tiny_tile(device, m, k, n, has_bias, grid_size, tile_h, tile_w, in0_sharded, out_sharded): + in0_shape = [1, 1, m, k] + in1_shape = [1, 1, k, n] + bias_shape = [1, 1, n] + + in0_block_w = k // grid_size[0] // 32 + out_block_h = m // grid_size[1] // tile_h + out_block_w = n // grid_size[0] // tile_w + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + if in0_sharded: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.BLOCK, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.L1_MEMORY_CONFIG + in0_t = ttnn.from_torch( + in0, + tile=(tile_h, 32), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in0_memory_config, + ) + in1_t = ttnn.from_torch( + in1, + tile=(32, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if has_bias: + bias = torch.randn(bias_shape).bfloat16().float() + bias_padded = bias.unsqueeze(2) + bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, tile_h - bias_padded.size(2)), "constant", 0) + bias_t = ttnn.from_torch( + bias_padded, + tile=(tile_h, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( + compute_with_storage_grid_size=grid_size, + in0_block_w=in0_block_w, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + transpose_mcast=False, + fused_activation=None, + ) + + if is_grayskull(): + compute_kernel_config = ttnn.GrayskullComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) + if out_sharded: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.L1_MEMORY_CONFIG + if out_sharded: + output_tile = ttnn.Tile([tile_h, 32]) if tile_h <= 16 else ttnn.Tile([tile_h, tile_w]) + else: + output_tile = ttnn.Tile([tile_h, tile_w]) + if has_bias: + output_t = ttnn.linear( + in0_t, + in1_t, + bias=bias_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=output_tile, + ) + else: + output_t = ttnn.matmul( + in0_t, + in1_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=output_tile, + ) + output_tensor = ttnn.to_torch(output_t) + pt_out = in0 @ in1 + if has_bias: + pt_out += bias + + assert_with_pcc(pt_out, 
output_tensor, 0.999) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("k", [1024]) +@pytest.mark.parametrize("n", [1024]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("grid_size", [(8, 4)]) +@pytest.mark.parametrize("tile_h", [16, 32]) +@pytest.mark.parametrize("tile_w", [16, 32]) +@pytest.mark.parametrize("in0_sharded", [True, False]) +@pytest.mark.parametrize("out_sharded", [True, False]) +def test_matmul_1d_tiny_tile(device, m, k, n, has_bias, grid_size, tile_h, tile_w, in0_sharded, out_sharded): + in0_shape = [1, 1, m, k] + in1_shape = [1, 1, k, n] + bias_shape = [1, 1, n] + + num_cores = grid_size[0] * grid_size[1] + + in0_block_w = k // num_cores // 32 + out_block_h = m // tile_h + out_block_w = n // num_cores // tile_w + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + if in0_sharded: + in0_memory_config = ttnn.create_sharded_memory_config( + (1, 1, m, k), + core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + ) + else: + in0_memory_config = ttnn.L1_MEMORY_CONFIG + in0_t = ttnn.from_torch( + in0, + tile=(tile_h, 32), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=in0_memory_config, + ) + in1_t = ttnn.from_torch( + in1, + tile=(32, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + if has_bias: + bias = torch.randn(bias_shape).bfloat16().float() + bias_padded = bias.unsqueeze(2) + bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, tile_h - bias_padded.size(2)), "constant", 0) + bias_t = ttnn.from_torch( + bias_padded, + tile=(tile_h, tile_w), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=grid_size, + in0_block_w=in0_block_w, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + fuse_batch=True, + fused_activation=None, + mcast_in0=True, + ) + + if is_grayskull(): + compute_kernel_config = ttnn.GrayskullComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) + if out_sharded: + out_mem_config = ttnn.MemoryConfig( + memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttnn.BufferType.L1, + ) + else: + out_mem_config = ttnn.L1_MEMORY_CONFIG + if out_sharded: + output_tile = ttnn.Tile([tile_h, 32]) if tile_h <= 16 else ttnn.Tile([tile_h, tile_w]) + else: + output_tile = ttnn.Tile([tile_h, tile_w]) + if has_bias: + output_t = ttnn.linear( + in0_t, + in1_t, + bias=bias_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=output_tile, + ) + else: + output_t = ttnn.matmul( + in0_t, + in1_t, + program_config=program_config, + memory_config=out_mem_config, + dtype=ttnn.bfloat16, + compute_kernel_config=compute_kernel_config, + output_tile=output_tile, + ) + output_tensor = ttnn.to_torch(output_t) + pt_out = in0 @ in1 + 
if has_bias: + pt_out += bias + + assert_with_pcc(pt_out, output_tensor, 0.999) # fmt: off diff --git a/tt_metal/common/test_tiles.hpp b/tt_metal/common/test_tiles.hpp index eed4f2450a1..083093cbed8 100644 --- a/tt_metal/common/test_tiles.hpp +++ b/tt_metal/common/test_tiles.hpp @@ -10,39 +10,51 @@ #include #include +#include +#include "tt_metal/common/constants.hpp" #include "tt_metal/common/assert.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" #include "math.hpp" enum TensorLayout { LIN_ROW_MAJOR = 0, // standard element-wise row-major - TILED32_SWIZZLED = 1, // row-major of tiles 32x32, each tile is row-major-swizzled - TILED32_4FACES = 2, // rowm major of tiles 32x32, each tile is 4 faces, each face is row-major, faces are swizzled + TILED_SWIZZLED = 1, // row-major of tiles, each tile is row-major-swizzled + TILED_NFACES = 2, // row-major of tiles, each tile is N (N = 1, 2, or 4) faces, each face is row-major, faces are swizzled }; template typename BufferType> -std::vector convert_to_tile_layout(const BufferType& data) { +std::vector convert_to_tile_layout( + const BufferType& data, + const std::optional>& tile_shape = std::nullopt, + const std::optional>& face_shape = std::nullopt) { ZoneScoped; std::vector result; - TT_ASSERT(data.size() / (32 * 32) > 0); - TT_ASSERT(data.size() % (32 * 32) == 0); - int num_tiles = data.size() / (32 * 32); + result.reserve(data.size()); + auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; + auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; + auto face_H = face_shape.has_value() ? face_shape.value()[0] : tt::constants::FACE_HEIGHT; + auto face_W = face_shape.has_value() ? face_shape.value()[1] : tt::constants::FACE_WIDTH; + auto tile_HW = tile_H * tile_W; + auto face_HW = face_H * face_W; + TT_ASSERT(data.size() / tile_HW > 0); + TT_ASSERT(data.size() % tile_HW == 0); + int num_tiles = data.size() / tile_HW; for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { std::vector top_left; std::vector top_right; std::vector bottom_left; std::vector bottom_right; - int index = tile_idx * (32 * 32); - for(int row = 0; row < 32; row++) { - for(int col = 0; col < 32; col++) { - if(row < 16 and col < 16) { + int index = tile_idx * tile_HW; + for(int row = 0; row < tile_H; row++) { + for(int col = 0; col < tile_W; col++) { + if(row < face_H and col < face_W) { top_left.push_back(data[index]); - } else if(row < 16 and col >= 16) { + } else if(row < face_H and col >= face_W) { top_right.push_back(data[index]); - } else if(row >= 16 and col < 16) { + } else if(row >= face_H and col < face_W) { bottom_left.push_back(data[index]); - } else if(row >= 16 and col >= 16) { + } else if(row >= face_H and col >= face_W) { bottom_right.push_back(data[index]); } else { TT_ASSERT(false); @@ -50,10 +62,10 @@ std::vector convert_to_tile_layout(const BufferType& data) { index++; } } - TT_ASSERT(top_left.size() == 16 * 16); - TT_ASSERT(top_right.size() == 16 * 16); - TT_ASSERT(bottom_left.size() == 16 * 16); - TT_ASSERT(bottom_right.size() == 16 * 16); + TT_ASSERT(top_left.size() == face_HW); + TT_ASSERT((top_right.size() == 0) or (top_right.size() == face_HW)); + TT_ASSERT((bottom_left.size() == 0) or (bottom_left.size() == face_HW)); + TT_ASSERT((bottom_right.size() == 0) or (bottom_right.size() == face_HW)); result.insert(result.end(), top_left.begin(), top_left.end()); result.insert(result.end(), top_right.begin(), top_right.end()); @@ -65,20 +77,32 @@ std::vector 
convert_to_tile_layout(const BufferType& data) { } template typename BufferTyp> -std::vector convert_to_flat_layout(const BufferTyp& data) { +std::vector convert_to_flat_layout( + const BufferTyp& data, + const std::optional>& tile_shape = std::nullopt, + const std::optional>& face_shape = std::nullopt) { ZoneScoped; std::vector result; - TT_ASSERT(data.size() / (32 * 32) > 0); - TT_ASSERT(data.size() % (32 * 32) == 0); - int num_tiles = data.size() / (32 * 32); + result.reserve(data.size()); + auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; + auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; + auto face_H = face_shape.has_value() ? face_shape.value()[0] : tt::constants::FACE_HEIGHT; + auto face_W = face_shape.has_value() ? face_shape.value()[1] : tt::constants::FACE_WIDTH; + auto tile_HW = tile_H * tile_W; + auto face_HW = face_H * face_W; + auto num_faces_row = tile_W / face_W; + auto num_faces_col = tile_H / face_H; + TT_ASSERT(data.size() / tile_HW > 0); + TT_ASSERT(data.size() % tile_HW == 0); + int num_tiles = data.size() / tile_HW; for(int tile_idx = 0; tile_idx < num_tiles; tile_idx++) { - int tile_start = tile_idx * (32 * 32); - for(int face_y = 0; face_y < 2; face_y++) { - for(int row = 0; row < 16; row++) { - int start = tile_start + face_y * (16 * 32) + row * 16; - for(int face_x = 0; face_x < 2; face_x++) { - int offset = face_x * (16 * 16); - for(int col = offset; col < offset + 16; col++) { + int tile_start = tile_idx * tile_HW; + for(int face_y = 0; face_y < num_faces_col; face_y++) { + for(int row = 0; row < face_H; row++) { + int start = tile_start + face_y * (face_H * tile_W) + row * face_W; + for(int face_x = 0; face_x < num_faces_row; face_x++) { + int offset = face_x * face_HW; + for(int col = offset; col < offset + face_W; col++) { result.push_back(data[start + col]); } } @@ -91,9 +115,12 @@ std::vector convert_to_flat_layout(const BufferTyp& data) { // Converts a 32-swizzled tilized row-major tensor to a linear 32-zero-padded row-major tensor template typename BufferType> -inline std::vector untilize_nchw(const BufferType& in, const std::vector& shape) { +inline std::vector untilize_nchw(const BufferType& in, const std::vector& shape, const std::optional>& tile_shape = std::nullopt) { ZoneScoped; - TT_ASSERT(shape[shape.size() - 2] % 32 == 0 && shape[shape.size() - 1] % 32 == 0); + auto tile_H = tile_shape.has_value() ? tile_shape.value()[0] : tt::constants::TILE_HEIGHT; + auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; + + TT_ASSERT(shape[shape.size() - 2] % tile_H == 0 && shape[shape.size() - 1] % tile_W == 0); std::vector result; // Untilize into row major @@ -105,13 +132,13 @@ inline std::vector untilize_nchw(const BufferType& in, const std::vector typename BufferType> -inline std::vector tilize_nchw(const BufferType& in_rowmajor, const std::vector& shape) { +inline std::vector tilize_nchw(const BufferType& in_rowmajor, const std::vector& shape, const std::optional>& tile_shape = std::nullopt) { ZoneScoped; int H = shape[shape.size() - 2], W = shape[shape.size() - 1]; auto batch_size = 1; @@ -138,19 +167,21 @@ inline std::vector tilize_nchw(const BufferType& in_rowmajor, const std::v batch_size *= shape[i]; } int input_volume = batch_size * H * W; - int OH = round_up_to_mul32(H); - int OW = round_up_to_mul32(W); + auto tile_H = tile_shape.has_value() ? 
tile_shape.value()[0] : tt::constants::TILE_HEIGHT; + auto tile_W = tile_shape.has_value() ? tile_shape.value()[1] : tt::constants::TILE_WIDTH; + int OH = round_up_to_tile(H, tile_H); + int OW = round_up_to_tile(W, tile_W); std::vector tilized_result; tilized_result.resize(batch_size * OH * OW); std::fill(tilized_result.begin(), tilized_result.end(), 0); int out_index = 0; for (auto batch_index = 0; batch_index < batch_size; batch_index++) { - for (int hs32 = 0; hs32 < H; hs32 += 32) { - for (int ws32 = 0; ws32 < W; ws32 += 32) { - for (int h32 = 0; h32 < 32; h32++) { - for (int w32 = 0; w32 < 32; w32++) { - auto w = w32 + ws32; - auto h = h32 + hs32; + for (int hs = 0; hs < H; hs += tile_H) { + for (int ws = 0; ws < W; ws += tile_W) { + for (int ht = 0; ht < tile_H; ht++) { + for (int wt = 0; wt < tile_W; wt++) { + auto w = wt + ws; + auto h = ht + hs; auto in_offs = w + h * W + batch_index * H * W; auto val = (w >= W || h >= H || in_offs >= input_volume) ? 0 : in_rowmajor[in_offs]; int out_w = (out_index % OW); @@ -189,32 +220,37 @@ struct TensAddr { template typename BufferType> inline std::vector convert_layout( - const BufferType& inp, const std::vector& shape, TensorLayout inL, TensorLayout outL) { + const BufferType& inp, + const std::vector& shape, + TensorLayout inL, + TensorLayout outL, + const std::optional>& tile_shape = std::nullopt, + const std::optional>& face_shape = std::nullopt) { ZoneScoped; switch (inL) { - case TILED32_SWIZZLED: - if (outL == TILED32_4FACES) { - return convert_to_tile_layout(inp); + case TILED_SWIZZLED: + if (outL == TILED_NFACES) { + return convert_to_tile_layout(inp, tile_shape, face_shape); } else if (outL == LIN_ROW_MAJOR) { - return untilize_nchw(inp, shape); + return untilize_nchw(inp, shape, tile_shape); } else TT_ASSERT(false && "Unsupported conversion."); break; case LIN_ROW_MAJOR: - if (outL == TILED32_SWIZZLED) { - return tilize_nchw(inp, shape); - } else if (outL == TILED32_4FACES) { - auto swiz32 = convert_layout(inp, shape, inL, TILED32_SWIZZLED); - return convert_layout(swiz32, shape, TILED32_SWIZZLED, outL); + if (outL == TILED_SWIZZLED) { + return tilize_nchw(inp, shape, tile_shape); + } else if (outL == TILED_NFACES) { + auto swiz32 = convert_layout(inp, shape, inL, TILED_SWIZZLED, tile_shape, face_shape); + return convert_layout(swiz32, shape, TILED_SWIZZLED, outL, tile_shape, face_shape); } else TT_ASSERT(false && "Unsupported conversion."); break; - case TILED32_4FACES: - if (outL == TILED32_SWIZZLED) { - return convert_to_flat_layout(inp); + case TILED_NFACES: + if (outL == TILED_SWIZZLED) { + return convert_to_flat_layout(inp, tile_shape, face_shape); } else if (outL == LIN_ROW_MAJOR) { - auto swiz32 = convert_layout(inp, shape, inL, TILED32_SWIZZLED); - return untilize_nchw(swiz32, shape); + auto swiz32 = convert_layout(inp, shape, inL, TILED_SWIZZLED, tile_shape, face_shape); + return untilize_nchw(swiz32, shape, tile_shape); } else { TT_ASSERT(false && "Unsupported conversion"); } diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h b/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h index 9795af7d63b..5fad25c0062 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h @@ -14,6 +14,7 @@ using namespace ckernel; #ifdef UCK_CHLKC_MATH #include "chlkc_unpack_data_format.h" +#include "chlkc_unpack_tile_dims.h" #include "chlkc_math_fidelity.h" #include "chlkc_math_approx_mode.h" #include "chlkc_dst_accum_mode.h" @@ -22,12 
+23,14 @@ using namespace ckernel;
 
 #ifdef UCK_CHLKC_PACK
 #include "chlkc_pack_data_format.h"
+#include "chlkc_pack_tile_dims.h"
 #include "chlkc_dst_accum_mode.h"
 #include "chlkc_pack.cpp"
 #endif
 
 #ifdef UCK_CHLKC_UNPACK
 #include "chlkc_unpack_data_format.h"
+#include "chlkc_unpack_tile_dims.h"
 #include "chlkc_dst_accum_mode.h"
 #include "chlkc_unpack.cpp"
 #endif
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_matmul_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_matmul_api.h
index 2950d33bcf8..4f1c045cf4e 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_matmul_api.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_matmul_api.h
@@ -21,7 +21,9 @@ inline void llk_math_matmul_init(
     const std::uint32_t in0_id = get_operand_id(operandA);
     const std::uint32_t in1_id = get_operand_id(operandB);
 
-    const bool partial_face = get_operand_partial_face(in0_id);
+    // TODO: this flag is only for computing the 8x32 tile shape, although the current impl assumes the in0 tile is still 16x32.
+    // We should remove this flag in the future and add an impl for the 8x32 input tile shape
+    const bool partial_face = 0;
 
     const std::uint32_t in0_tile_r_dim = get_operand_tile_r_dim(in0_id);
     const std::uint32_t in0_tile_c_dim = get_operand_tile_c_dim(in0_id);
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_unpack_AB_matmul_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_unpack_AB_matmul_api.h
index ec6cb98277a..ced21fa1c79 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_unpack_AB_matmul_api.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_unpack_AB_matmul_api.h
@@ -6,6 +6,7 @@
 #include "llk_unpack_AB_matmul.h"
 #include "llk_unpack_common_api.h"
 
+
 /*************************************************************************
 * LLK UNPACK AB MATMUL
 *************************************************************************/
@@ -111,15 +112,16 @@ inline void llk_unpack_AB_matmul(
     const std::uint32_t operandA_id = get_operand_id(operandA);
     const std::uint32_t operandB_id = get_operand_id(operandB);
 
-    const std::uint32_t unpA_face_r_dim = get_operand_face_r_dim(operandB_id);  // In1/InB -> srcA
-    const std::uint32_t unpB_face_r_dim = get_operand_face_r_dim(operandA_id);  // In0/InA -> srcB
-    const bool partial_face_a = get_operand_partial_face(operandA_id);
-    const bool partial_face_b = get_operand_partial_face(operandB_id);
+    const std::uint32_t unpA_face_r_dim = get_operand_face_r_dim(operandB_id);  // In1/InB -> srcA - unused in lower API
+    const std::uint32_t unpB_face_r_dim = get_operand_face_r_dim(operandA_id);  // In0/InA -> srcB - unused in lower API
+
+    // TODO: remove the partial_face flag, as it is easily confused with the partial face flag in the math kernel
+    const bool partial_face_a = get_operand_partial_face(operandB_id);  // In1/InB -> srcA
+    const bool partial_face_b = get_operand_partial_face(operandA_id);  // In0/InA -> srcB
 
     std::uint32_t base_address_a = cb_interface[operandA_id].fifo_rd_ptr - 1;
     std::uint32_t base_address_b = cb_interface[operandB_id].fifo_rd_ptr - 1;
-
     std::uint32_t tile_size_a = cb_interface[operandA_id].fifo_page_size;
     std::uint32_t tile_size_b = cb_interface[operandB_id].fifo_page_size;
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_operands.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_operands.h
index ea113ce5fa0..91f164495aa 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_operands.h
+++ 
b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_operands.h @@ -23,30 +23,30 @@ inline const uint32_t get_operand_dst_format(const std::uint32_t operand_id) inline const uint32_t get_operand_num_faces(const std::uint32_t operand_id) { - return 4; + return (uint32_t)unpack_tile_num_faces[operand_id]; } inline const uint32_t get_operand_partial_face(const std::uint32_t operand_id) { - return 0; + return (uint32_t)unpack_partial_face[operand_id]; } inline const uint32_t get_operand_face_r_dim(const std::uint32_t operand_id) { - return 16; + return (uint32_t)unpack_tile_face_r_dim[operand_id]; } inline const uint32_t get_operand_narrow_tile(const std::uint32_t operand_id) { - return 0; + return (uint32_t)unpack_narrow_tile[operand_id]; } inline const uint32_t get_operand_tile_r_dim(const std::uint32_t operand_id) { - return 32; + return (uint32_t)unpack_tile_r_dim[operand_id]; } inline const uint32_t get_operand_tile_c_dim(const std::uint32_t operand_id) { - return 32; + return (uint32_t)unpack_tile_c_dim[operand_id]; } diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h index 74c71eb9751..71772bd323f 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h @@ -30,30 +30,30 @@ inline const unsigned char get_output_dst_format(const std::uint32_t output_id) inline const uint32_t get_output_num_faces(const std::uint32_t output_id) { - return 4; + return (uint32_t)pack_tile_num_faces[output_id]; } inline const uint32_t get_output_partial_face(const std::uint32_t output_id) { - return 0; + return (uint32_t)pack_partial_face[output_id]; } inline const uint32_t get_output_face_r_dim(const std::uint32_t output_id) { - return 16; + return (uint32_t)pack_tile_face_r_dim[output_id]; } inline const uint32_t get_output_narrow_tile(const std::uint32_t output_id) { - return 0; + return (uint32_t)pack_narrow_tile[output_id]; } inline const uint32_t get_output_tile_r_dim(const std::uint32_t output_id) { - return 32; + return (uint32_t)pack_tile_r_dim[output_id]; } inline const uint32_t get_output_tile_c_dim(const std::uint32_t output_id) { - return 32; + return (uint32_t)pack_tile_c_dim[output_id]; } diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index b983a09e325..b46cec89b7b 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -7,6 +7,7 @@ #if __has_include("chlkc_unpack_data_format.h") #include "chlkc_pack_data_format.h" #include "chlkc_unpack_data_format.h" +#include "chlkc_unpack_tile_dims.h" #define DATA_FORMATS_DEFINED #endif #if __has_include("generated_bank_to_noc_coord_mapping.h") @@ -211,24 +212,31 @@ constexpr static std::int32_t GET_TILE_SIZE(uint format) { }; } -FORCE_INLINE -constexpr static std::uint32_t MUL_WITH_TILE_SIZE(uint format, uint index) { +template +FORCE_INLINE constexpr static std::uint32_t MUL_WITH_TILE_SIZE(uint format, uint index) { + constexpr uint8_t datum_shift = (tile_hw == 1024) ? 10 : + (tile_hw == 512) ? 9 : + (tile_hw == 256) ? 8 : 10; + + constexpr uint8_t exp_shift = (tile_hw == 1024) ? 2 : + (tile_hw == 512) ? 1 : + (tile_hw == 256) ? 
0 : 2; switch (format & 0x1F) { - case ((uint8_t)DataFormat::UInt8): return (index << 10); + case ((uint8_t)DataFormat::UInt8): return (index << datum_shift); case ((uint8_t)DataFormat::UInt16): case ((uint8_t)DataFormat::Float16): - case ((uint8_t)DataFormat::Float16_b): return (index << 11); + case ((uint8_t)DataFormat::Float16_b): return (index << (datum_shift + 1)); case ((uint8_t)DataFormat::Int32): case ((uint8_t)DataFormat::UInt32): - case ((uint8_t)DataFormat::Float32): return (index << 12); + case ((uint8_t)DataFormat::Float32): return (index << (datum_shift + 2)); case ((uint8_t)DataFormat::Bfp2): - case ((uint8_t)DataFormat::Bfp2_b): return ((index << 8) + (index << 6)); + case ((uint8_t)DataFormat::Bfp2_b): return ((index << (datum_shift - 2)) + (index << (4 + exp_shift))); case ((uint8_t)DataFormat::Bfp4): - case ((uint8_t)DataFormat::Bfp4_b): return ((index << 9) + (index << 6)); + case ((uint8_t)DataFormat::Bfp4_b): return ((index << (datum_shift - 1)) + (index << (4 + exp_shift))); case ((uint8_t)DataFormat::Bfp8): case ((uint8_t)DataFormat::Bfp8_b): // Keep default as Bfp8? - default: return ((index << 10) + (index << 6)); + default: return ((index << datum_shift) + (index << (4 + exp_shift))); }; } @@ -321,12 +329,22 @@ constexpr inline std::int32_t get_tile_size(const std::int32_t operand) { std::uint32_t input = operand; // L1 16B words - std::uint32_t num_words = GET_TILE_SIZE((uint)unpack_src_format[input]); + std::uint32_t num_words = (uint)unpack_tile_size[input]; // return bytes return num_words; } +constexpr inline uint32_t get_tile_hw(const std::int32_t operand) { + std::uint32_t input = operand; + return (uint32_t)unpack_tile_r_dim[input] * (uint32_t)unpack_tile_c_dim[input]; +} + +constexpr inline uint32_t get_tile_num_faces(const std::int32_t operand) { + std::uint32_t input = operand; + return (uint32_t)unpack_tile_num_faces[input]; +} + constexpr inline DataFormat get_dataformat(const std::int32_t operand) { return static_cast((uint)unpack_src_format[operand]); } @@ -892,7 +910,7 @@ struct InterleavedPow2AddrGen { } }; -template +template struct InterleavedAddrGenFast { uint32_t bank_base_address; // Base address for the whole tensor. // TODO: Remove page_size from argument list. This can be derived from data_format @@ -901,7 +919,7 @@ struct InterleavedAddrGenFast { FORCE_INLINE uint32_t get_addr(const uint32_t id, const uint32_t bank_offset_index, const uint32_t bank_index, const uint32_t offset = 0) const { - return MUL_WITH_TILE_SIZE((uint)this->data_format, bank_offset_index) + this->bank_base_address + offset + interleaved_addr_gen::get_bank_offset(bank_index); + return MUL_WITH_TILE_SIZE((uint)this->data_format, bank_offset_index) + this->bank_base_address + offset + interleaved_addr_gen::get_bank_offset(bank_index); } FORCE_INLINE @@ -1090,8 +1108,8 @@ FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedPow2 return s.get_noc_addr(id, offset); } -template -FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGenFast& s, uint32_t offset = 0) { +template +FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGenFast& s, uint32_t offset = 0) { /* Alternative API for getting the noc address when we are reading using a swizzled layout. This version assumes bank unit size can be arbitrary size. 
Use @@ -1116,9 +1134,9 @@ FORCE_INLINE void noc_async_read_page( s.noc_async_read_page(id, dst_local_l1_addr, offset); } -template +template FORCE_INLINE void noc_async_read_tile( - const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0) { + const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0) { /* Read requests - use static VC Read responses - assigned VCs dynamically @@ -1167,9 +1185,9 @@ void noc_async_write(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr } } -template +template FORCE_INLINE void noc_async_write_tile( - const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t src_local_l1_addr) { + const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t src_local_l1_addr) { s.noc_async_write_tile(id, src_local_l1_addr); } diff --git a/tt_metal/impl/buffers/circular_buffer.cpp b/tt_metal/impl/buffers/circular_buffer.cpp index e4f1a180db7..d50e2141164 100644 --- a/tt_metal/impl/buffers/circular_buffer.cpp +++ b/tt_metal/impl/buffers/circular_buffer.cpp @@ -89,6 +89,13 @@ DataFormat CircularBuffer::data_format(uint32_t buffer_index) const { return this->config_.data_formats().at(buffer_index).value(); } +const std::optional& CircularBuffer::tile(uint32_t buffer_index) const { + if (not this->uses_buffer_index(buffer_index)) { + TT_THROW("Cannot access tile dims for buffer index {} because circular buffer is not configured on that index", buffer_index); + } + return this->config_.tiles().at(buffer_index); +} + uint32_t CircularBuffer::address() const { if (not locally_allocated_address_.has_value() and not this->globally_allocated()) { TT_THROW("Circular buffer has not been allocated, cannot request address at this time!"); diff --git a/tt_metal/impl/buffers/circular_buffer.hpp b/tt_metal/impl/buffers/circular_buffer.hpp index 8d9c44077b3..0e305aeb76c 100644 --- a/tt_metal/impl/buffers/circular_buffer.hpp +++ b/tt_metal/impl/buffers/circular_buffer.hpp @@ -36,6 +36,8 @@ class CircularBuffer { DataFormat data_format(uint32_t buffer_index) const; + const std::optional& tile(uint32_t buffer_index) const; + uint32_t address() const; bool is_on_logical_corerange(const CoreRange &logical_cr) const; diff --git a/tt_metal/impl/buffers/circular_buffer_types.hpp b/tt_metal/impl/buffers/circular_buffer_types.hpp index 939f93e4fba..b27fc39ce61 100644 --- a/tt_metal/impl/buffers/circular_buffer_types.hpp +++ b/tt_metal/impl/buffers/circular_buffer_types.hpp @@ -14,6 +14,7 @@ #include "tt_metal/common/tt_backend_api_types.hpp" #include "tt_metal/hostdevcommon/common_runtime_address_map.h" #include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/tile/tile.hpp" namespace tt::tt_metal { @@ -99,6 +100,15 @@ class CircularBufferConfig { return *this; } + CircularBufferConfig set_tile_dims(uint8_t buffer_index, const Tile& tile) { + this->tiles_[buffer_index] = tile; + return *this; + } + + const std::array, NUM_CIRCULAR_BUFFERS> &tiles() const { + return this->tiles_; + } + uint32_t total_size() const { return this->total_size_; } std::optional globally_allocated_address() const { return this->globally_allocated_address_; } @@ -134,6 +144,7 @@ class CircularBufferConfig { std::optional globally_allocated_address_ = std::nullopt; std::array, NUM_CIRCULAR_BUFFERS> data_formats_; std::array, NUM_CIRCULAR_BUFFERS> page_sizes_; + std::array, NUM_CIRCULAR_BUFFERS> tiles_; std::unordered_set buffer_indices_; bool dynamic_cb_ = false; // `max_size_` is used to ensure that 
total size does not grow beyond associated buffer size diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 482ed1278b1..d2de62a7f4b 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -591,6 +591,37 @@ void Program::set_cb_data_fmt(Device *device, const std::vector &crs, } } +void Program::set_cb_tile_dims(Device *device, const std::vector &crs, JitBuildOptions &build_options) const { + ZoneScoped; + for (const auto &logical_cr : crs) { + auto cbs_on_core = this->circular_buffers_on_corerange(logical_cr); + for (const auto &circular_buffer : cbs_on_core) { + for (auto buffer_index : circular_buffer->buffer_indices()) { + auto tile = circular_buffer->tile(buffer_index); + if (tile.has_value()) { + build_options.set_cb_tile_dims_all_cores( + static_cast(buffer_index), + tile->get_num_faces(), + tile->get_partial_face(), + tile->get_face_shape()[0], + tile->get_narrow_tile(), + tile->get_tile_shape()[0], + tile->get_tile_shape()[1]); + build_options.set_cb_tile_size_all_cores( + static_cast(buffer_index), + tile->get_tile_size(circular_buffer->data_format(buffer_index))); + } else { + Tile t; + build_options.set_cb_tile_size_all_cores( + static_cast(buffer_index), + t.get_tile_size(circular_buffer->data_format(buffer_index))); + } + + } + } + } +} + void Program::invalidate_compile() { for (auto &[device_id, compile_needed] : compile_needed_) { compile_needed = true; @@ -1018,6 +1049,7 @@ void Program::compile(Device *device, bool fd_bootloader_mode) { JitBuildOptions build_options(device->build_env()); kernel->set_build_options(build_options); this->set_cb_data_fmt(device, kernel->logical_coreranges(), build_options); + this->set_cb_tile_dims(device, kernel->logical_coreranges(), build_options); auto kernel_hash = KernelCompileHash(kernel, build_options, device->build_key()); std::string kernel_path_suffix = kernel->name() + "/" + std::to_string(kernel_hash) + "/"; diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index 3a2d6652317..0b7d82a2083 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -239,6 +239,8 @@ class Program { void set_cb_data_fmt( Device *device, const std::vector & crs, JitBuildOptions& build_options) const; + void set_cb_tile_dims( Device *device, const std::vector & crs, JitBuildOptions& build_options) const; + void update_kernel_groups(uint32_t programmable_core_type_index); uint32_t& get_program_config_size(uint32_t programmable_core_type_index); diff --git a/tt_metal/impl/tile/tile.hpp b/tt_metal/impl/tile/tile.hpp new file mode 100644 index 00000000000..578b150da20 --- /dev/null +++ b/tt_metal/impl/tile/tile.hpp @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include
+
+#include "common/bfloat16.hpp"
+#include "common/tt_backend_api_types.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/common/math.hpp"
+
+namespace tt {
+
+namespace tt_metal {
+
+constexpr std::array, 2>, 12> TILE_FACE_HW_CHOICES = {{
+    // TODO: add other tile shapes once llk supports them
+    {{ {32, 32}, {16, 16} }},
+    {{ {16, 32}, {16, 16} }},
+    {{ {32, 16}, {16, 16} }},
+    {{ {16, 16}, {16, 16} }},
+    // these shapes are not supported yet on llk, just for host loopback
+    {{ {8, 32}, {8, 16} }},
+    {{ {4, 32}, {4, 16} }},
+    {{ {2, 32}, {2, 16} }},
+    {{ {1, 32}, {1, 16} }},
+    // these shapes are not supported yet on llk, just for host loopback
+    {{ {8, 16}, {8, 16} }},
+    {{ {4, 16}, {4, 16} }},
+    {{ {2, 16}, {2, 16} }},
+    {{ {1, 16}, {1, 16} }}
+}};
+
+struct Tile {
+    std::array tile_shape = {constants::TILE_HEIGHT, constants::TILE_WIDTH};
+    std::array face_shape = {constants::FACE_HEIGHT, constants::FACE_WIDTH};
+    uint32_t tile_hw = constants::TILE_HW;
+    uint32_t face_hw = constants::FACE_HW;
+    uint32_t num_faces = constants::TILE_HW / constants::FACE_HW;
+    uint32_t partial_face = 0;
+    uint32_t narrow_tile = 0;
+
+    Tile(const std::array& tile_shape = {constants::TILE_HEIGHT, constants::TILE_WIDTH}) : tile_shape(tile_shape) {
+        auto it = std::find_if(TILE_FACE_HW_CHOICES.begin(), TILE_FACE_HW_CHOICES.end(),
+            [this, &tile_shape](const auto& pair) {
+                if (pair[0] == tile_shape) {
+                    this->face_shape = pair[1];
+                    return true;
+                }
+                return false;
+            });
+        if (it == TILE_FACE_HW_CHOICES.end()) {
+            TT_THROW("Tile size is not valid for our hardware");
+        }
+
+        tile_hw = tile_shape[0] * tile_shape[1];
+        face_hw = face_shape[0] * face_shape[1];
+        num_faces = tile_hw / face_hw;
+        partial_face = (uint32_t)(tile_shape[0] < constants::TILE_HEIGHT);
+        narrow_tile = (uint32_t)(tile_shape[1] < constants::TILE_WIDTH);
+    }
+
+    // Getter methods
+    const uint32_t get_num_faces() const { return num_faces; }
+    const uint32_t get_tile_hw() const { return tile_hw; }
+    const uint32_t get_face_hw() const { return face_hw; }
+    const uint32_t get_partial_face() const { return partial_face; }
+    const uint32_t get_narrow_tile() const { return narrow_tile; }
+    const std::array get_tile_shape() const { return tile_shape; }
+    const std::array get_face_shape() const { return face_shape; }
+
+    const uint32_t get_tile_size(const DataFormat& format) const {
+        switch (format) {
+            case DataFormat::Bfp2:
+            case DataFormat::Bfp2_b: return (tile_hw / 4) + (16 * num_faces);
+            case DataFormat::Bfp4:
+            case DataFormat::Bfp4_b: return (tile_hw / 2) + (16 * num_faces);
+            case DataFormat::Bfp8:
+            case DataFormat::Bfp8_b: return tile_hw + (16 * num_faces);
+            case DataFormat::Float16:
+            case DataFormat::Float16_b: return (tile_hw * 2);
+            case DataFormat::Float32: return (tile_hw * 4);
+            case DataFormat::Tf32: throw std::invalid_argument("TF32 unsupported atm");
+            case DataFormat::Int8: return tile_hw;
+            case DataFormat::Lf8: return tile_hw;
+            case DataFormat::UInt8: return tile_hw;
+            case DataFormat::UInt16: return (tile_hw * 2);
+            case DataFormat::UInt32: return (tile_hw * 4);
+            case DataFormat::RawUInt8: return tile_hw;
+            case DataFormat::RawUInt16: return (tile_hw * 2);
+            case DataFormat::Int32: return (tile_hw * 4);
+            case DataFormat::RawUInt32: return (tile_hw * 4);
+            case DataFormat::Invalid: throw std::invalid_argument("Invalid data format");
+            default: throw std::invalid_argument("Unknown format");
+        }
+    }
+
+    // operators
+    bool
operator==(const Tile& other) const { + return tile_shape == other.tile_shape && face_shape == other.face_shape; + } + + static constexpr auto attribute_names = std::forward_as_tuple("tile_shape", "face_shape", "num_faces"); + const auto attribute_values() const { return std::forward_as_tuple(tile_shape, face_shape, num_faces); } +}; + +} // namespace tt_metal + +} // namespace tt diff --git a/tt_metal/include/compute_kernel_api/bcast.h b/tt_metal/include/compute_kernel_api/bcast.h index ba3c5d7dd97..fe66b90b204 100644 --- a/tt_metal/include/compute_kernel_api/bcast.h +++ b/tt_metal/include/compute_kernel_api/bcast.h @@ -197,7 +197,7 @@ ALWI void mul_tiles_bcast(uint32_t icb0, uint32_t icb1, uint32_t itile0, uint32_ */ ALWI void add_bcast_rows_init_short(uint32_t icb0 = 0, uint32_t icb1 = 1) { - MATH(( llk_math_eltwise_binary_init() )); + MATH(( llk_math_eltwise_binary_init_with_operands(icb0, icb1) )); UNPACK(( llk_unpack_AB_init(icb0, icb1) )); } diff --git a/tt_metal/jit_build/genfiles.cpp b/tt_metal/jit_build/genfiles.cpp index 3c55740ed90..2d0253fac41 100644 --- a/tt_metal/jit_build/genfiles.cpp +++ b/tt_metal/jit_build/genfiles.cpp @@ -331,6 +331,52 @@ static void generate_data_format_descriptors(JitBuildOptions& options, const tt: emit_pack_data_formats(pack_data_format_descs, pack_src_formats_all_cbs, pack_dst_formats_all_cbs); } +static std::string array_to_string(const uint32_t arr[]) { + std::string formats_string = ""; + for (int i = 0; i < NUM_CIRCULAR_BUFFERS; i++) { + formats_string += to_string((int)arr[i]) + ","; + } + return formats_string; +} + +static void emit_unpack_tile_dims(std::string unpack_tile_dims_descs, tt_hlk_desc& desc) { + ofstream file_stream; + file_stream.open(unpack_tile_dims_descs); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_tile_num_faces", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_num_faces_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_partial_face", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_partial_face_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_tile_face_r_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_face_r_dim_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_narrow_tile", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_narrow_tile_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_tile_r_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_tile_r_dim_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "unpack_tile_c_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_tile_c_dim_arr)); + file_stream << create_formats_array_string("constexpr uint16_t", "unpack_tile_size", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_tile_size_arr)); + file_stream.close(); +} + +static void emit_pack_tile_dims(std::string pack_tile_dims_descs, tt_hlk_desc& desc) { + ofstream file_stream; + file_stream.open(pack_tile_dims_descs); + file_stream << create_formats_array_string("constexpr uint8_t", "pack_tile_num_faces", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_num_faces_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "pack_partial_face", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_partial_face_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "pack_tile_face_r_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_face_r_dim_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", 
"pack_narrow_tile", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_narrow_tile_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "pack_tile_r_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_tile_r_dim_arr)); + file_stream << create_formats_array_string("constexpr uint8_t", "pack_tile_c_dim", NUM_CIRCULAR_BUFFERS, array_to_string(desc.buf_tile_c_dim_arr)); + file_stream.close(); +} + +static void generate_tile_dims_descriptors(JitBuildOptions& options, const tt::ARCH arch) { + string out_file_name_base = "chlkc_"; + string out_file_name_suffix = "_tile_dims.h"; + string unpack_tile_dims_descs = options.path + out_file_name_base + "unpack" + out_file_name_suffix; + string pack_tile_dims_descs = options.path + out_file_name_base + "pack" + out_file_name_suffix; + + // assuming all cores within a op have the same desc + tt_hlk_desc& desc = options.hlk_desc; + + emit_unpack_tile_dims(unpack_tile_dims_descs, desc); + emit_pack_tile_dims(pack_tile_dims_descs, desc); +} + static void generate_dst_accum_mode_descriptor(JitBuildOptions& options) { string dst_accum_format_descriptor = options.path + "chlkc_dst_accum_mode.h"; @@ -381,10 +427,12 @@ void jit_build_genfiles_descriptors(const JitBuildEnv& env, fs::create_directories(options.path); try { std::thread td( [&]() { generate_data_format_descriptors(options, env.get_arch()); } ); + std::thread tt( [&]() { generate_tile_dims_descriptors(options, env.get_arch()); } ); std::thread tm( [&]() { generate_math_fidelity_descriptor(options); } ); std::thread ta( [&]() { generate_math_approx_mode_descriptor(options); } ); std::thread tf( [&]() { generate_dst_accum_mode_descriptor(options); } ); td.join(); + tt.join(); tm.join(); ta.join(); tf.join(); diff --git a/tt_metal/jit_build/hlk_desc.hpp b/tt_metal/jit_build/hlk_desc.hpp index 5cd3c1cc265..9fa6d39cd4d 100644 --- a/tt_metal/jit_build/hlk_desc.hpp +++ b/tt_metal/jit_build/hlk_desc.hpp @@ -34,6 +34,13 @@ class tt_hlk_desc DataFormat param_buf_dataformat_arr[8]; DataFormat output_buf_dataformat_arr[8]; DataFormat intermediate_buf_dataformat_arr[8]; + uint32_t buf_num_faces_arr[32]; + uint32_t buf_partial_face_arr[32]; + uint32_t buf_face_r_dim_arr[32]; + uint32_t buf_narrow_tile_arr[32]; + uint32_t buf_tile_r_dim_arr[32]; + uint32_t buf_tile_c_dim_arr[32]; + uint32_t buf_tile_size_arr[32]; tt_hlk_desc() { @@ -50,6 +57,17 @@ class tt_hlk_desc output_buf_dataformat_arr[i] = DataFormat::Invalid; intermediate_buf_dataformat_arr[i] = DataFormat::Invalid; } + + for (int i = 0; i < 32; ++i) + { + buf_num_faces_arr[i] = constants::TILE_HW / constants::FACE_HW; + buf_partial_face_arr[i] = 0; + buf_face_r_dim_arr[i] = constants::FACE_HEIGHT; + buf_narrow_tile_arr[i] = 0; + buf_tile_r_dim_arr[i] = constants::TILE_HEIGHT; + buf_tile_c_dim_arr[i] = constants::TILE_WIDTH; + buf_tile_size_arr[i] = constants::BFLOAT8_B_TILE_HW; + } } tt_hlk_desc(tt_hlk_desc &in) @@ -62,6 +80,17 @@ class tt_hlk_desc intermediate_buf_dataformat_arr[i] = in.intermediate_buf_dataformat_arr[i]; } + for (int i = 0; i < 32; ++i) + { + buf_num_faces_arr[i] = in.buf_num_faces_arr[i]; + buf_partial_face_arr[i] = in.buf_partial_face_arr[i]; + buf_face_r_dim_arr[i] = in.buf_face_r_dim_arr[i]; + buf_narrow_tile_arr[i] = in.buf_narrow_tile_arr[i]; + buf_tile_r_dim_arr[i] = in.buf_tile_r_dim_arr[i]; + buf_tile_c_dim_arr[i] = in.buf_tile_c_dim_arr[i]; + buf_tile_size_arr[i] = in.buf_tile_size_arr[i]; + } + math_fidelity = in.math_fidelity; hlk_file_name = in.hlk_file_name; hlk_args = in.hlk_args; @@ -109,6 +138,76 @@ class 
tt_hlk_desc intermediate_buf_dataformat_arr[buf_idx] = data_format; } + uint32_t get_buf_num_faces(int buf_idx) const + { + return buf_num_faces_arr[buf_idx]; + } + + void set_buf_num_faces(int buf_idx, uint32_t num_faces) + { + buf_num_faces_arr[buf_idx] = num_faces; + } + + uint32_t get_buf_partial_face(int buf_idx) const + { + return buf_partial_face_arr[buf_idx]; + } + + void set_buf_partial_face(int buf_idx, uint32_t partial_face) + { + buf_partial_face_arr[buf_idx] = partial_face; + } + + uint32_t get_buf_face_r_dim(int buf_idx) const + { + return buf_face_r_dim_arr[buf_idx]; + } + + void set_buf_face_r_dim(int buf_idx, uint32_t face_r_dim) + { + buf_face_r_dim_arr[buf_idx] = face_r_dim; + } + + uint32_t get_buf_narrow_tile(int buf_idx) const + { + return buf_narrow_tile_arr[buf_idx]; + } + + void set_buf_narrow_tile(int buf_idx, uint32_t narrow_tile) + { + buf_narrow_tile_arr[buf_idx] = narrow_tile; + } + + uint32_t get_buf_tile_r_dim(int buf_idx) const + { + return buf_tile_r_dim_arr[buf_idx]; + } + + void set_buf_tile_r_dim(int buf_idx, uint32_t tile_r_dim) + { + buf_tile_r_dim_arr[buf_idx] = tile_r_dim; + } + + uint32_t get_buf_tile_c_dim(int buf_idx) const + { + return buf_tile_c_dim_arr[buf_idx]; + } + + void set_buf_tile_c_dim(int buf_idx, uint32_t tile_c_dim) + { + buf_tile_c_dim_arr[buf_idx] = tile_c_dim; + } + + uint32_t get_buf_tile_size(int buf_idx) const + { + return buf_tile_size_arr[buf_idx]; + } + + void set_buf_tile_size(int buf_idx, uint32_t tile_size) + { + buf_tile_size_arr[buf_idx] = tile_size; + } + void set_hlk_args(void* args, size_t size) { hlk_args = args; @@ -203,6 +302,11 @@ struct std::hash } tt::utils::hash_combine(hash_value, hash{}(obj.get_hlk_math_fidelity())); tt::utils::hash_combine(hash_value, hash{}(obj.get_hlk_math_approx_mode())); + for (int i = 0; i < 32; i++) + { + tt::utils::hash_combine(hash_value, hash{}(obj.get_buf_tile_r_dim(i))); + tt::utils::hash_combine(hash_value, hash{}(obj.get_buf_tile_c_dim(i))); + } // Get hash for hlk_args here void *hlk_args = obj.get_hlk_args(); diff --git a/tt_metal/jit_build/settings.cpp b/tt_metal/jit_build/settings.cpp index 14b6e9cb768..ed0e1e04503 100644 --- a/tt_metal/jit_build/settings.cpp +++ b/tt_metal/jit_build/settings.cpp @@ -47,6 +47,19 @@ namespace tt::tt_metal set_hlk_operand_dataformat_all_cores((HlkOperand)cb_id, data_format); } + void JitBuildOptions::set_cb_tile_dims_all_cores(CB cb_id, uint32_t num_faces, uint32_t partial_face, uint32_t face_r_dim, uint32_t narrow_tile, uint32_t tile_r_dim, uint32_t tile_c_dim) { + hlk_desc.set_buf_num_faces((int)cb_id, num_faces); + hlk_desc.set_buf_partial_face((int)cb_id, partial_face); + hlk_desc.set_buf_face_r_dim((int)cb_id, face_r_dim); + hlk_desc.set_buf_narrow_tile((int)cb_id, narrow_tile); + hlk_desc.set_buf_tile_r_dim((int)cb_id, tile_r_dim); + hlk_desc.set_buf_tile_c_dim((int)cb_id, tile_c_dim); + } + + void JitBuildOptions::set_cb_tile_size_all_cores(CB cb_id, uint32_t tile_size) { + hlk_desc.set_buf_tile_size((int)cb_id, tile_size); + } + void JitBuildOptions::set_hlk_operand_dataformat_all_cores(HlkOperand op_id, DataFormat data_format) { static_assert(HlkOperand::in7 == int(HlkOperand::param0)-1); diff --git a/tt_metal/jit_build/settings.hpp b/tt_metal/jit_build/settings.hpp index ab0ecd9ac11..36559c474c2 100644 --- a/tt_metal/jit_build/settings.hpp +++ b/tt_metal/jit_build/settings.hpp @@ -52,6 +52,8 @@ class JitBuildOptions { void set_hlk_args_all_cores(void* args, size_t size); void set_cb_dataformat_all_cores(CB cb_id, DataFormat 
data_format); + void set_cb_tile_dims_all_cores(CB cb_id, uint32_t num_faces, uint32_t partial_face, uint32_t face_r_dim, uint32_t narrow_tile, uint32_t tile_r_dim, uint32_t tile_c_dim); + void set_cb_tile_size_all_cores(CB cb_id, uint32_t tile_size); // old API name void set_hlk_operand_dataformat_all_cores(HlkOperand op_id, DataFormat data_format); }; diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 2ba04930039..0eb95b3e378 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -64,16 +64,16 @@ void log_external_operation( #endif template -Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::vector& shape, DataType data_type, Layout layout) +Tensor create_owned_tensor(T* data_ptr, size_t num_elements, std::vector& shape, DataType data_type, Layout layout, const std::optional& optional_tile = std::nullopt) { auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); auto storage = OwnedStorage{std::move(buffer)}; - return Tensor(std::move(storage), shape, data_type, layout); + return Tensor(std::move(storage), shape, data_type, layout, optional_tile); } Tensor convert_torch_tensor_to_tt_tensor( - const py::handle &torch_tensor, std::optional optional_data_type = std::nullopt, bool enable_borrow = true) { + const py::handle &torch_tensor, std::optional optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt, bool enable_borrow = true) { py::object torch = py::module_::import("torch"); if (not py::isinstance(torch_tensor, torch.attr("Tensor"))) { TT_THROW("The argument must be of type torch.Tensor!"); @@ -163,9 +163,9 @@ Tensor convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::UINT16: { @@ -173,9 +173,9 @@ Tensor convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::INT32: { @@ -183,9 +183,9 @@ Tensor convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::UINT32: { @@ -193,9 +193,9 @@ Tensor 
convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::FLOAT32: { @@ -203,9 +203,9 @@ Tensor convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::BFLOAT16: { @@ -213,9 +213,9 @@ Tensor convert_torch_tensor_to_tt_tensor( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR); + return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); } } case DataType::BFLOAT8_B: { @@ -225,7 +225,7 @@ Tensor convert_torch_tensor_to_tt_tensor( auto buffer = owned_buffer::create(std::move(uint32_vector)); auto storage = OwnedStorage{std::move(buffer)}; // TODO(arakhmati): should it be Layout::TILE? - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(torch_data_ptr); @@ -234,7 +234,7 @@ Tensor convert_torch_tensor_to_tt_tensor( auto buffer = owned_buffer::create(std::move(uint32_vector)); auto storage = OwnedStorage{std::move(buffer)}; // TODO(arakhmati): should it be Layout::TILE? 
- return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -244,7 +244,7 @@ Tensor convert_torch_tensor_to_tt_tensor( } Tensor convert_numpy_tensor_to_tt_tensor( - const py::handle &np_tensor, std::optional optional_data_type = std::nullopt) { + const py::handle &np_tensor, std::optional optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt) { py::object np = py::module_::import("numpy"); if (not py::isinstance(np_tensor, np.attr("ndarray"))) { TT_THROW("The tensor must be of type numpy.ndarray!"); @@ -333,31 +333,31 @@ Tensor convert_numpy_tensor_to_tt_tensor( auto data_ptr = reinterpret_cast(np_data_ptr); auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::UINT16: { auto data_ptr = reinterpret_cast(np_data_ptr); auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::INT32: { auto data_ptr = reinterpret_cast(np_data_ptr); auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::UINT32: { auto data_ptr = reinterpret_cast(np_data_ptr); auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::FLOAT32: { auto data_ptr = reinterpret_cast(np_data_ptr); auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } /* case DataType::BFLOAT16: { @@ -374,7 +374,7 @@ Tensor convert_numpy_tensor_to_tt_tensor( auto buffer = owned_buffer::create(std::move(uint32_vector)); auto storage = OwnedStorage{std::move(buffer)}; // TODO(arakhmati): should it be Layout::TILE? - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(np_data_ptr); @@ -383,7 +383,7 @@ Tensor convert_numpy_tensor_to_tt_tensor( auto buffer = owned_buffer::create(std::move(uint32_vector)); auto storage = OwnedStorage{std::move(buffer)}; // TODO(arakhmati): should it be Layout::TILE? 
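Since the optional tile is now threaded through every Tensor construction in this file, a rough host-side sketch of the new argument (not part of the patch) looks like the following. It reuses the create_owned_buffer_from_vector_of_floats and Tensor(...) calls visible in the bindings below; the element type of shape and the helper's exact signature are assumptions, since template arguments are stripped in this view of the diff.

std::vector<float> data(1 * 1 * 16 * 64, 0.0f);
std::array<uint32_t, 4> shape = {1, 1, 16, 64};
DataType data_type = DataType::FLOAT32;
Layout layout = Layout::ROW_MAJOR;
std::optional<Tile> tile = Tile({16, 32});  // 16x32 tiny tile instead of the default 32x32

auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type);
Tensor host_tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile);

From Python, the same path is reached through the new tile keyword argument (py::arg("tile") = std::nullopt) added to the tensor constructors further down, and the tile can be read back through the new tile property / get_tile() binding.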
- return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR); + return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -393,17 +393,17 @@ Tensor convert_numpy_tensor_to_tt_tensor( } Tensor convert_python_tensor_to_tt_tensor( - const py::handle &tensor, std::optional optional_data_type = std::nullopt, bool enable_borrow = true) { + const py::handle &tensor, std::optional optional_data_type = std::nullopt, const std::optional& optional_tile = std::nullopt, bool enable_borrow = true) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", tensor, optional_data_type, enable_borrow); py::object torch = py::module_::import("torch"); py::object np = py::module_::import("numpy"); if (py::isinstance(tensor, torch.attr("Tensor"))) { - auto output = convert_torch_tensor_to_tt_tensor(tensor, optional_data_type, enable_borrow); + auto output = convert_torch_tensor_to_tt_tensor(tensor, optional_data_type, optional_tile, enable_borrow); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; } else if (py::isinstance(tensor, np.attr("ndarray"))) { - auto output = convert_numpy_tensor_to_tt_tensor(tensor, optional_data_type); + auto output = convert_numpy_tensor_to_tt_tensor(tensor, optional_data_type, optional_tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -412,11 +412,11 @@ Tensor convert_python_tensor_to_tt_tensor( } } -Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optional data_type, const std::unordered_map& strategy) { +Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optional data_type, const std::optional tile, const std::unordered_map& strategy) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_python_tensors_to_tt_tensors", tensor_shards, data_type, strategy); std::vector tt_shards; for (const auto &shard : tensor_shards) { - tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, false)); + tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false)); } std::vector host_owned_buffers; std::vector host_owned_shapes; @@ -428,7 +428,7 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona auto distributed_tensor_config = get_distributed_tensor_config(strategy); auto storage = MultiDeviceHostStorage{distributed_tensor_config, std::move(host_owned_buffers), host_owned_shapes}; - auto output = Tensor(std::move(storage), tt_shards.at(0).get_legacy_shape(), tt_shards.at(0).get_dtype(), Layout::ROW_MAJOR); + auto output = Tensor(std::move(storage), tt_shards.at(0).get_legacy_shape(), tt_shards.at(0).get_dtype(), Layout::ROW_MAJOR, tt_shards.at(0).get_tile()); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -726,10 +726,16 @@ void pytensor_module(py::module &m_tensor) { py::init<>([](std::vector &&data, const std::array &shape, DataType data_type, - Layout layout) { + Layout layout, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - return Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); + return Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); }), + py::arg("data"), + 
py::arg("shape"), + py::arg("data_type"), + py::arg("layout"), + py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( +---------------+---------------+ @@ -761,12 +767,19 @@ void pytensor_module(py::module &m_tensor) { const std::array &shape, DataType data_type, Layout layout, - Device *device) { + Device *device, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); + auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); return tensor.to(device, MemoryConfig{}); }), py::keep_alive<1, 6>(), + py::arg("data"), + py::arg("shape"), + py::arg("data_type"), + py::arg("layout"), + py::arg("device") = std::nullopt, + py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( +---------------+---------------+ @@ -808,12 +821,20 @@ void pytensor_module(py::module &m_tensor) { DataType data_type, Layout layout, Device *device, - const MemoryConfig &memory_config) { + const MemoryConfig &memory_config, + const std::optional &tile) { auto owned_buffer = detail::create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); + auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); return tensor.to(device, memory_config); }), - py::keep_alive<1, 6>(), + py::keep_alive<1, 7>(), + py::arg("data"), + py::arg("shape"), + py::arg("data_type"), + py::arg("layout"), + py::arg("device") = std::nullopt, + py::arg("memory_config"), + py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( +---------------+---------------+ @@ -856,15 +877,17 @@ void pytensor_module(py::module &m_tensor) { .def( py::init<>([](const py::object &tensor, std::optional data_type, - const std::unordered_map &strategy) { + const std::unordered_map &strategy, + const std::optional &tile) { if (py::isinstance(tensor)) { - return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, strategy); + return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, tile, strategy); } - return detail::convert_python_tensor_to_tt_tensor(tensor, data_type); + return detail::convert_python_tensor_to_tt_tensor(tensor, data_type, tile); }), py::arg("tensor"), py::arg("data_type") = std::nullopt, py::arg("strategy") = std::unordered_map(), + py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( +--------------+------------------------+ @@ -887,8 +910,9 @@ void pytensor_module(py::module &m_tensor) { std::optional data_type, Device *device, Layout layout, - const MemoryConfig &mem_config) { - auto tensor = detail::convert_python_tensor_to_tt_tensor(python_tensor, data_type); + const MemoryConfig &mem_config, + const std::optional &tile) { + auto tensor = detail::convert_python_tensor_to_tt_tensor(python_tensor, data_type, tile); auto layout_tensor = tensor.to(layout); return layout_tensor.to(device, mem_config); }), @@ -897,6 +921,7 @@ void pytensor_module(py::module &m_tensor) { py::arg("device").noconvert(), py::arg("layout").noconvert(), py::arg("mem_config").noconvert(), + py::arg("tile") = std::nullopt, py::return_value_policy::move, R"doc( +--------------+------------------------+ @@ -924,6 +949,7 @@ void pytensor_module(py::module &m_tensor) { .def_property_readonly("shape", [](const Tensor &self) { return self.get_shape(); }) .def_property_readonly("dtype", [](const Tensor &self) { 
return self.get_dtype(); }) .def_property_readonly("layout", [](const Tensor &self) { return self.get_layout(); }) + .def_property_readonly("tile", [](const Tensor &self) { return self.get_tile(); }) .def( "deallocate", [](Tensor &self, bool force) { return self.deallocate(force); }, @@ -1530,6 +1556,15 @@ void pytensor_module(py::module &m_tensor) { layout = tt_tensor.get_layout() + )doc") + .def( + "get_tile", [](const Tensor &self) { return self.get_tile(); }, R"doc( + Get tile dims of TT Tensor. + + .. code-block:: python + + tile = tt_tensor.get_tile() + )doc") .def( "memory_config", [](const Tensor &self) { return self.memory_config(); }, R"doc( diff --git a/ttnn/cpp/pybind11/tensor.cpp b/ttnn/cpp/pybind11/tensor.cpp index 1925da8ec9d..b3f6d148040 100644 --- a/ttnn/cpp/pybind11/tensor.cpp +++ b/ttnn/cpp/pybind11/tensor.cpp @@ -81,6 +81,10 @@ void tensor_mem_config_module_types(py::module& m_tensor) { Class defining core coordinate )doc"); + py::class_(m_tensor, "Tile", R"doc( + Class defining tile dims + )doc"); + py::class_(m_tensor, "Shape", R"doc( Class defining tensor shape )doc"); @@ -128,6 +132,19 @@ void tensor_mem_config_module(py::module& m_tensor) { .def_readonly("y", &CoreCoord::y); py::implicitly_convertible, CoreCoord>(); + auto py_tile = static_cast>(m_tensor.attr("Tile")); + py_tile.def(py::init&>()) + .def(py::init<>([](const std::array& tile) { + return Tile{tile}; + })) + .def("__repr__", [](const Tile& self) { + return fmt::format("Tile with shape: [{}, {}]", self.get_tile_shape()[0], self.get_tile_shape()[1]); + }) + .def_readonly("tile_shape", &Tile::tile_shape) + .def_readonly("face_shape", &Tile::face_shape) + .def_readonly("num_faces", &Tile::num_faces); + py::implicitly_convertible, Tile>(); + auto py_shape = static_cast>(m_tensor.attr("Shape")); py_shape.def(py::init>()) .def( diff --git a/ttnn/cpp/ttnn/async_runtime.cpp b/ttnn/cpp/ttnn/async_runtime.cpp index 5b0a65318ea..2a8bc818df6 100644 --- a/ttnn/cpp/ttnn/async_runtime.cpp +++ b/ttnn/cpp/ttnn/async_runtime.cpp @@ -17,8 +17,9 @@ DeviceBuffer allocate_interleaved_buffer_on_device( const Shape& shape, DataType data_type, Layout layout, - const MemoryConfig& memory_config) { - uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value); + const MemoryConfig& memory_config, + const std::optional& tile) { + uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value, tile); return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); } @@ -34,14 +35,15 @@ DeviceBuffer allocate_sharded_buffer_on_device( DataType data_type, Layout layout, const ShardSpecBuffer& shard_params, - const MemoryConfig& memory_config) { + const MemoryConfig& memory_config, + const std::optional& tile) { tt::tt_metal::tensor_impl::validate_sharded_buffer_allocation( - shape.value, layout, data_type, shard_params, memory_config); + shape.value, layout, data_type, shard_params, memory_config, tile); const auto& page_shape = shard_params.page_shape; uint32_t size_of_element = tt::tt_metal::tensor_impl::element_size_bytes(data_type); uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element; if (layout == Layout::TILE) { - page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value); + page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value, tile); } return std::make_shared( @@ -55,15 +57,16 @@ DeviceBuffer 
allocate_buffer_on_device( DataType data_type, Layout layout, const MemoryConfig& memory_config, - const std::optional& shard_spec) { + const std::optional& shard_spec, + const std::optional& tile) { if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { return allocate_interleaved_buffer_on_device( - buffer_size_bytes, device, shape, data_type, layout, memory_config); + buffer_size_bytes, device, shape, data_type, layout, memory_config, tile); } else if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK) { return allocate_contiguous_buffer_on_device(buffer_size_bytes, device, memory_config); } else { return allocate_sharded_buffer_on_device( - buffer_size_bytes, device, shape, data_type, layout, shard_spec.value(), memory_config); + buffer_size_bytes, device, shape, data_type, layout, shard_spec.value(), memory_config, tile); } } diff --git a/ttnn/cpp/ttnn/async_runtime.hpp b/ttnn/cpp/ttnn/async_runtime.hpp index 0c001316932..672a214f210 100644 --- a/ttnn/cpp/ttnn/async_runtime.hpp +++ b/ttnn/cpp/ttnn/async_runtime.hpp @@ -12,7 +12,7 @@ namespace ttnn { using DeviceBuffer = std::shared_ptr; using queue_id = uint8_t; - DeviceBuffer allocate_buffer_on_device(size_t buffer_size_bytes, types::Device* device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec = std::nullopt); + DeviceBuffer allocate_buffer_on_device(size_t buffer_size_bytes, types::Device* device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec = std::nullopt, const std::optional& tile = std::nullopt); void write_buffer(queue_id cq_id, Tensor& dst, std::vector> src, const std::optional transfer_size = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index e8c6d156cef..62c185b2a2e 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -73,14 +73,15 @@ Tensor to_layout_impl( TT_THROW("ttnn::to_layout: Unsupported layout conversion from {} to {}!", tensor_arg.get_layout(), layout); } - const auto requires_padding_change = [](ttnn::Layout layout, const ttnn::Shape& shape) -> bool { + const auto requires_padding_change = [](ttnn::Tensor& tensor, ttnn::Layout layout, const ttnn::Shape& shape) -> bool { const auto intended_shape = shape; const auto padded_shape = shape.with_tile_padding(); if (layout == ttnn::ROW_MAJOR_LAYOUT and intended_shape != padded_shape) { return true; } else if ( - layout == ttnn::TILE_LAYOUT and (padded_shape.rank() < 2 or padded_shape[-1] % ttnn::TILE_SIZE != 0 or - padded_shape[-2] % ttnn::TILE_SIZE != 0)) { + auto tile = tensor.tile(); + layout == ttnn::TILE_LAYOUT and (padded_shape.rank() < 2 or padded_shape[-1] % tile.get_tile_shape()[1] != 0 or + padded_shape[-2] % tile.get_tile_shape()[0] != 0)) { return true; } else { return false; @@ -116,7 +117,7 @@ Tensor to_layout_impl( bool use_multicore_untilize = true; bool use_multicore_tilize = use_multicore_device_tilize(tensor, dtype); - if (not requires_padding_change(layout, tensor.get_shape())) { + if (not requires_padding_change(tensor, layout, tensor.get_shape())) { if (layout == ttnn::ROW_MAJOR_LAYOUT) { TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize); @@ 
-182,7 +183,7 @@ Tensor to_layout_impl( } } else { TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!"); - if (not requires_padding_change(layout, tensor.get_shape())) { + if (not requires_padding_change(tensor, layout, tensor.get_shape())) { return device ? tensor.to(layout, device) : tensor.to(layout); } else if (layout == ttnn::ROW_MAJOR_LAYOUT) { tensor = device ? tensor.to(layout, device) : tensor.to(layout); diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp index d403e564c6e..0fe7e184d2c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp @@ -149,7 +149,6 @@ void MAIN { const uint32_t total_tiles_per_col = get_arg_val(2); const uint32_t num_chunks_per_row = get_arg_val(3); - untilize_init(cb_a_in); binary_op_init_common(cb_a_in, cb_bx_in); // Fill initial hidden states diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp index d0dec2d76d2..a377d69278b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp @@ -81,7 +81,6 @@ void MAIN { constexpr uint32_t Wt = get_compile_time_arg_val(10); constexpr uint32_t half_Wt = get_compile_time_arg_val(11); - binary_op_init_common(in_cb, cos_cb); cb_wait_front(scalar_cb, onetile); diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp index 533d1176360..c0a22d0327b 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp @@ -15,6 +15,7 @@ #include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" + // Please update // tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp // when making any changes to this file. 
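For clarity on the to_layout change above (not part of the patch), the tile-layout branch of requires_padding_change can be restated as a small hypothetical helper. The real code reads the tile from the tensor via tensor.tile() inside the lambda and additionally checks padded_shape.rank() < 2, which is omitted here.

// Hypothetical restatement: the last two padded dims must be multiples of the
// tensor's tile height and width, no longer of the fixed ttnn::TILE_SIZE (32).
bool needs_padding_for_tile_layout(
    uint32_t padded_h, uint32_t padded_w, const std::array<uint32_t, 2>& tile_shape) {
    return (padded_h % tile_shape[0] != 0) || (padded_w % tile_shape[1] != 0);
}

// With a 16x32 tile: a padded [..., 16, 64] shape is already aligned (false),
// while [..., 20, 64] still requires padding (true).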
@@ -156,6 +157,7 @@ void MAIN { cb_wait_front(in0_cb_id, in0_block_num_tiles); cb_wait_front(in1_cb_id, in1_block_num_tiles); + int in0_index_subblock_offset = 0; for (uint32_t in0_subblock = 0; in0_subblock < in0_num_subblocks; in0_subblock++) { int in1_index_subblock_offset = 0; @@ -196,6 +198,7 @@ void MAIN { in1_index += in1_per_core_w; // to stride down by 1 need to stride by in_per_core_w (should be // called in1_block_w) } + #endif // SKIP_COMPUTE if (last_out) { diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp index 166bb8e9c33..7b1904fd2ce 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp @@ -41,10 +41,11 @@ void kernel_main() { const uint32_t in0_single_tile_size_bytes = get_tile_size(cb_id_in0); const DataFormat in0_data_format = get_dataformat(cb_id_in0); + constexpr const uint32_t in0_tile_hw = get_tile_hw(cb_id_in0); uint32_t l1_write_addr_in0; - const InterleavedAddrGenFast s0 = { + const InterleavedAddrGenFast s0 = { .bank_base_address = in0_tensor_addr, .page_size = in0_single_tile_size_bytes, .data_format = in0_data_format}; for (uint32_t b = 0; b < batch; ++b) { diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp index 2e76f54632e..54dbe232292 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_padding.cpp @@ -82,8 +82,8 @@ void kernel_main() { } #else constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0); - - const InterleavedAddrGenFast s0 = { + constexpr const uint32_t in0_tile_hw = get_tile_hw(cb_id_in0); + const InterleavedAddrGenFast s0 = { .bank_base_address = in0_tensor_addr, .page_size = in0_single_tile_size_bytes, .data_format = in0_data_format}; #endif diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp index f41cb351131..f611a09e90f 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp @@ -75,13 +75,14 @@ void kernel_main() { // WRITER // single-tile const uint32_t output_single_tile_size_bytes = get_tile_size(cb_id_out0); + constexpr const uint32_t output_tile_hw = get_tile_hw(cb_id_out0); const DataFormat output_data_format = get_dataformat(cb_id_out0); volatile tt_l1_ptr uint32_t* in1_mcast_receiver_semaphore_addr_ptr = reinterpret_cast(in1_mcast_receiver_semaphore_addr); // WRITER - const InterleavedAddrGenFast s = { + const InterleavedAddrGenFast s = { .bank_base_address = out_tensor_addr, .page_size = output_single_tile_size_bytes, .data_format = output_data_format}; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp 
b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp index 41d09cc92b4..a370668322f 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp @@ -87,10 +87,11 @@ void kernel_main() { constexpr uint32_t cb_id_in3 = 3; constexpr uint32_t bias_single_tile_size_bytes = get_tile_size(cb_id_in3); constexpr DataFormat bias_data_format = get_dataformat(cb_id_in3); + constexpr const uint32_t in3_tile_hw = get_tile_hw(cb_id_in3); uint32_t l1_write_addr_in3; - const InterleavedAddrGenFast s3 = { + const InterleavedAddrGenFast s3 = { .bank_base_address = in3_tensor_addr, .page_size = bias_single_tile_size_bytes, .data_format = bias_data_format}; @@ -125,6 +126,7 @@ void kernel_main() { constexpr uint32_t cb_id_in1 = 1; constexpr uint32_t in1_single_tile_size_bytes = get_tile_size(cb_id_in1); + constexpr const uint32_t in1_tile_hw = get_tile_hw(cb_id_in1); constexpr uint32_t in1_block_size_bytes = in1_block_num_tiles * in1_single_tile_size_bytes; @@ -136,15 +138,16 @@ void kernel_main() { uint32_t l1_write_addr_in1; constexpr DataFormat in1_data_format = get_dataformat(cb_id_in1); - const InterleavedAddrGenFast s1 = { + const InterleavedAddrGenFast s1 = { .bank_base_address = in1_tensor_addr, .page_size = in1_single_tile_size_bytes, .data_format = in1_data_format}; #endif // WRITER constexpr uint32_t cb_id_out0 = 16; constexpr uint32_t output_single_tile_size_bytes = get_tile_size(cb_id_out0); + constexpr const uint32_t output_tile_hw = get_tile_hw(cb_id_out0); constexpr DataFormat output_data_format = get_dataformat(cb_id_out0); - const InterleavedAddrGenFast s = { + const InterleavedAddrGenFast s = { .bank_base_address = out_tensor_addr, .page_size = output_single_tile_size_bytes, .data_format = output_data_format}; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp index 40bc1d601ed..322de708588 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp @@ -6,6 +6,7 @@ #include "dataflow_api.h" + void kernel_main() { // in0/in1 common args const uint32_t num_blocks = get_arg_val(0); @@ -63,7 +64,8 @@ void kernel_main() { #else const uint32_t in1_single_tile_size_bytes = get_tile_size(cb_id_in1); const DataFormat in1_data_format = get_dataformat(cb_id_in1); - const InterleavedAddrGenFast s1 = { + constexpr const uint32_t in1_tile_hw = get_tile_hw(cb_id_in1); + const InterleavedAddrGenFast s1 = { .bank_base_address = in1_tensor_addr, .page_size = in1_single_tile_size_bytes, .data_format = in1_data_format}; uint32_t l1_write_addr_in1; #endif @@ -71,8 +73,9 @@ void kernel_main() { #ifndef OUT_SHARDED const uint32_t output_single_tile_size_bytes = get_tile_size(cb_id_out0); const DataFormat output_data_format = get_dataformat(cb_id_out0); + constexpr const uint32_t output_tile_hw = get_tile_hw(cb_id_out0); - const InterleavedAddrGenFast s = { + const InterleavedAddrGenFast s = { .bank_base_address = out_tensor_addr, .page_size = output_single_tile_size_bytes, .data_format = output_data_format}; diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp 
b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 0d3913aca37..89fcec4c24f 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -308,14 +308,16 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config( uint32_t K = input_tensor_a.get_legacy_shape()[-1]; uint32_t N = input_tensor_b.get_legacy_shape()[-1]; uint32_t per_core_M, per_core_N; + auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); if (mcast_in0) { - per_core_M = M / TILE_HEIGHT; - per_core_N = div_up(div_up(N, grid_size.x * grid_size.y), TILE_WIDTH); + per_core_M = M / in0_tile_shape[0]; + per_core_N = div_up(div_up(N, grid_size.x * grid_size.y), in1_tile_shape[1]); } else { - per_core_M = div_up(div_up(M, grid_size.x * grid_size.y), TILE_HEIGHT); - per_core_N = N / TILE_WIDTH; + per_core_M = div_up(div_up(M, grid_size.x * grid_size.y), in0_tile_shape[0]); + per_core_N = N / in1_tile_shape[1]; } - uint32_t in0_block_w = K / TILE_WIDTH % 2 == 0 ? 2 : 1; + uint32_t in0_block_w = K / in0_tile_shape[1] % 2 == 0 ? 2 : 1; bool per_core_N_equals_subblock_w_constraint = out_sharded && !mcast_in0; bool per_core_M_equals_subblock_h_constraint = out_sharded && mcast_in0; bool fp32_dest_acc_en = get_fp32_dest_acc_en(compute_kernel_config); @@ -349,11 +351,14 @@ inline MatmulProgramConfig create_simple_matmul_program_config( uint32_t batch_size_a = get_batch_size(ashape); uint32_t num_output_tiles = batch_size_a * ashape[-2] * bshape[-1] / TILE_HW; // Output M x N + auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + // Parameters for large matmul with reuse uint32_t B = batch_size_a; - uint32_t Mt = ashape[-2] / TILE_HEIGHT; - uint32_t Kt = ashape[-1] / TILE_WIDTH; - uint32_t Nt = bshape[-1] / TILE_WIDTH; + uint32_t Mt = ashape[-2] / in0_tile_shape[0]; + uint32_t Kt = ashape[-1] / in0_tile_shape[1]; + uint32_t Nt = bshape[-1] / in1_tile_shape[1]; uint32_t in0_block_w = 2; TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "input tensor needs to be on device"); @@ -577,6 +582,9 @@ MatmulProgramConfig get_matmul_program_config( // generic sharded output tensor creation auto grid_size = input_tensor_a.shard_spec().value().grid.bounding_box().grid_size(); + auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + // MCAST matmuls only support input_b in INTERLEAVED if (matmul) { TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error"); @@ -592,9 +600,9 @@ MatmulProgramConfig get_matmul_program_config( per_core_N_equals_subblock_w_constraint = true; } - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; auto shard_shape = input_tensor_a.shard_spec().value().shape; bool mcast_in0; @@ -605,10 +613,10 @@ MatmulProgramConfig get_matmul_program_config( mcast_in0 = true; per_core_M = M; per_core_N = div_up(N, 
input_tensor_a.shard_spec().value().grid.num_cores()); - in0_block_w = std::gcd(shard_shape[1] / TILE_WIDTH, K); + in0_block_w = std::gcd(shard_shape[1] / in0_tile_shape[1], K); } else if (input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { mcast_in0 = false; - per_core_M = shard_shape[0] / TILE_HEIGHT; + per_core_M = shard_shape[0] / in0_tile_shape[0]; per_core_N = N; // Only necessary if output is sharded; otherwise, can set this to be < N in0_block_w = K; } else { @@ -643,25 +651,25 @@ MatmulProgramConfig get_matmul_program_config( per_core_N_equals_subblock_w_constraint = true; } - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; auto shard_shape = input_tensor_a.shard_spec().value().shape; uint32_t virtual_x = transpose_mcast ? grid_size.y : grid_size.x; uint32_t virtual_y = transpose_mcast ? grid_size.x : grid_size.y; - bool cores_along_x_match_grid_size = virtual_x == (K / (shard_shape[1] / TILE_WIDTH)); - bool cores_along_y_match_grid_size = virtual_y == (M / (shard_shape[0] / TILE_HEIGHT)); + bool cores_along_x_match_grid_size = virtual_x == (K / (shard_shape[1] / in0_tile_shape[1])); + bool cores_along_y_match_grid_size = virtual_y == (M / (shard_shape[0] / in0_tile_shape[0])); TT_FATAL( - cores_along_y_match_grid_size || virtual_y == div_up(M, (shard_shape[0] / TILE_HEIGHT)), + cores_along_y_match_grid_size || virtual_y == div_up(M, (shard_shape[0] / in0_tile_shape[0])), "Num cores along y must match provided grid size!"); TT_FATAL( - cores_along_x_match_grid_size || virtual_x == div_up(K, (shard_shape[1] / TILE_WIDTH)), + cores_along_x_match_grid_size || virtual_x == div_up(K, (shard_shape[1] / in0_tile_shape[1])), "Num cores along x must match provided grid size!"); uint32_t per_core_M = div_up(M, virtual_y); uint32_t per_core_N = div_up(N, virtual_x); - uint32_t in0_block_w = cores_along_x_match_grid_size ? std::gcd(shard_shape[1] / TILE_WIDTH, K) : 1; + uint32_t in0_block_w = cores_along_x_match_grid_size ? 
std::gcd(shard_shape[1] / in0_tile_shape[1], K) : 1; auto subblock_hw = bmm_op_utils::get_matmul_subblock_params( per_core_M, per_core_N, false, per_core_N_equals_subblock_w_constraint, fp32_dest_acc_en); @@ -690,14 +698,14 @@ MatmulProgramConfig get_matmul_program_config( per_core_N_equals_subblock_w_constraint = true; } - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; auto in0_shard_shape = input_tensor_a.shard_spec().value().shape; - uint32_t per_core_M = in0_shard_shape[0] / TILE_HEIGHT; + uint32_t per_core_M = in0_shard_shape[0] / in0_tile_shape[0]; uint32_t per_core_N = N; - uint32_t in0_block_w = in0_shard_shape[1] / TILE_WIDTH; + uint32_t in0_block_w = in0_shard_shape[1] / in0_tile_shape[1]; auto subblock_hw = bmm_op_utils::get_matmul_subblock_params( per_core_M, per_core_N, false, per_core_N_equals_subblock_w_constraint, fp32_dest_acc_en); @@ -902,7 +910,10 @@ Matmul create_matmul_struct( parameters.untilize_out, parameters.user_core_coord, parameters.user_fused_activation, - parameters.user_run_batched}; + parameters.user_run_batched, + parameters.transpose_a, + parameters.transpose_b, + parameters.output_tile}; } Tensor matmul( @@ -951,6 +962,21 @@ void Matmul::validate( const auto& input_tensor_b = input_tensors.at(1); const auto& a_shape = input_tensor_a.get_shape(); const auto& b_shape = input_tensor_b.get_shape(); + auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + + if (input_tensor_a.device()->arch() == tt::ARCH::GRAYSKULL) { + TT_FATAL( + (input_tensor_a.get_tile().get_tile_shape()[1] == TILE_WIDTH && input_tensor_a.get_tile().get_tile_shape()[0] == TILE_HEIGHT), + "Grayskull does not support tiny tile"); + TT_FATAL( + (input_tensor_b.get_tile().get_tile_shape()[1] == TILE_WIDTH && input_tensor_b.get_tile().get_tile_shape()[0] == TILE_HEIGHT), + "Grayskull does not support tiny tile"); + } + + TT_FATAL( + (input_tensor_a.get_tile().get_tile_shape()[1] == TILE_WIDTH && input_tensor_b.get_tile().get_tile_shape()[0] == TILE_WIDTH), + "Input tile dims must have inner dim equal to 32 due to llk constraints"); TT_FATAL( (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE), @@ -990,13 +1016,17 @@ void Matmul::validate( TT_FATAL(optional_input_tensors.size() == 1, "Error"); const auto& optional_bias = optional_input_tensors.at(0); if (optional_bias.has_value()) { + TT_FATAL( + (optional_bias->get_tile().get_tile_shape()[0] == input_tensor_a.get_tile().get_tile_shape()[0] && + optional_bias->get_tile().get_tile_shape()[1] == input_tensor_b.get_tile().get_tile_shape()[1]), + "Input tile dims must have inner dim equal to 32 due to llk constraints"); const auto& bias = optional_bias.value(); TT_FATAL(bias.get_layout() == Layout::TILE, "Unsupported input layout"); const auto& bias_shape = bias.get_shape(); uint32_t bias_batch_size = get_batch_size(bias_shape); TT_FATAL(bias_batch_size == 1, "Unsupported bias shape: batch size not equal to 1."); TT_FATAL( - bias_shape.with_tile_padding()[-2] == TILE_HEIGHT, "Unsupported bias shape: second last 
dimension not equal to tile height"); + bias_shape.with_tile_padding()[-2] == in0_tile_shape[0], "Unsupported bias shape: second last dimension not equal to tile height"); TT_FATAL( bias_shape.with_tile_padding()[-1] == b_shape.with_tile_padding()[-1], "Unsupported bias shape: last dimension not equal to second input's last dimension.", bias_shape.with_tile_padding()[-1], b_shape.with_tile_padding()[-1]); @@ -1010,7 +1040,7 @@ void Matmul::validate( } std::visit( - [input_tensor_a, input_tensor_b, this](const auto& program_config) { + [input_tensor_a, input_tensor_b, in0_tile_shape, in1_tile_shape, this](const auto& program_config) { using ProgramConfigType = std::decay_t; // TODO: For 1D and 2D mcasts, we don't check if tensor is single core or single row/col // We can uplift these variants to skip mcasting to support single core (1D) or single row/col (2D) @@ -1027,27 +1057,26 @@ void Matmul::validate( TT_FATAL(input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, "Error"); uint32_t M = (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] - : input_tensor_a.get_legacy_shape()[-2]) / - TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; + : input_tensor_a.get_legacy_shape()[-2]) / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; auto shard_shape = input_tensor_a.shard_spec().value().shape; // No padding TT_FATAL(M == per_core_M, "Error"); - TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT), "Error"); + TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); TT_FATAL(K % program_config.in0_block_w == 0, "Error"); - TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0, "Error"); + TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error"); } if (this->output_mem_config.is_sharded()) { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error"); uint32_t M = (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] : input_tensor_a.get_legacy_shape()[-2]) / - TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; @@ -1068,24 +1097,22 @@ void Matmul::validate( TT_FATAL(input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, "Error"); uint32_t M = (program_config.fuse_batch ? 
input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] - : input_tensor_a.get_legacy_shape()[-2]) / - TILE_HEIGHT; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; + : input_tensor_a.get_legacy_shape()[-2]) / in0_tile_shape[0]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; auto shard_shape = input_tensor_a.shard_spec().value().shape; TT_FATAL(div_up(M, per_core_M) == input_tensor_a.shard_spec().value().grid.num_cores(), "Error"); - TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT), "Error"); + TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); TT_FATAL(K % program_config.in0_block_w == 0, "Error"); - TT_FATAL(K == (shard_shape[1] / TILE_WIDTH), "Error"); + TT_FATAL(K == (shard_shape[1] / in0_tile_shape[1]), "Error"); } if (this->output_mem_config.is_sharded()) { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); uint32_t M = (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] - : input_tensor_a.get_legacy_shape()[-2]) / - TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + : input_tensor_a.get_legacy_shape()[-2]) / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; @@ -1103,27 +1130,27 @@ void Matmul::validate( TT_FATAL(input_tensor_a.memory_config().buffer_type == this->output_mem_config.buffer_type, "Error"); TT_FATAL(input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout, "Error"); TT_FATAL(input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, "Error"); - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; auto shard_shape = input_tensor_a.shard_spec().value().shape; // No padding TT_FATAL(M == per_core_M, "Error"); - TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT), "Error"); + TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); TT_FATAL(K % program_config.in0_block_w == 0, "Error"); - TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0, "Error"); + TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error"); // tensor in1 TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error"); } else if constexpr (std::is_same_v) { if (input_tensor_a.memory_config().is_sharded()) { auto tensor_a_memory_layout = input_tensor_a.memory_config().memory_layout; - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t K = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t K = input_tensor_a.get_legacy_shape()[-1] / 
in0_tile_shape[1]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; auto shard_shape = input_tensor_a.shard_spec().value().shape; @@ -1148,14 +1175,14 @@ void Matmul::validate( } else if (tensor_a_memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { TT_FATAL(!program_config.transpose_mcast, "Error"); TT_FATAL(K == program_config.in0_block_w, "Error"); - TT_FATAL(program_config.in0_block_w == (shard_shape[1] / TILE_WIDTH), "Error"); + TT_FATAL(program_config.in0_block_w == (shard_shape[1] / in0_tile_shape[1]), "Error"); TT_FATAL( input_tensor_a.shard_spec()->grid.bounding_box().start_coord.x == input_tensor_a.shard_spec()->grid.bounding_box().end_coord.x, "Error"); } - TT_FATAL(per_core_M == (shard_shape[0] / TILE_HEIGHT), "Error"); - TT_FATAL((shard_shape[1] / TILE_WIDTH) % program_config.in0_block_w == 0, "Error"); + TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); + TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error"); } if (input_tensor_b.memory_config().is_sharded()) { @@ -1164,7 +1191,7 @@ void Matmul::validate( TT_FATAL(tensor_b_memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error"); if (input_tensor_b.buffer()->buffer_type() != tt_metal::BufferType::DRAM) { TT_FATAL( - program_config.per_core_N == (input_tensor_b.shard_spec().value().shape[1] / TILE_WIDTH), "Error"); + program_config.per_core_N == (input_tensor_b.shard_spec().value().shape[1] / in1_tile_shape[1]), "Error"); } TT_FATAL( input_tensor_b.shard_spec()->grid.bounding_box().start_coord.y == @@ -1173,17 +1200,17 @@ void Matmul::validate( if (this->output_mem_config.is_sharded()) { TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED, "Error"); - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; TT_FATAL(program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1, "Error"); } } else if constexpr (std::is_same_v) { - uint32_t M = input_tensor_a.get_legacy_shape()[-2] / TILE_HEIGHT; - uint32_t total_M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.get_legacy_shape()[-2] / in0_tile_shape[0]; + uint32_t total_M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t K = input_tensor_a.get_legacy_shape()[-1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; @@ -1199,8 +1226,8 @@ void Matmul::validate( auto in0_shard_shape = input_tensor_a.shard_spec().value().shape; TT_FATAL(K == in0_shard_shape[1], "Error"); - TT_FATAL(in0_shard_shape[1] == program_config.in0_block_w * TILE_WIDTH, "Error"); - TT_FATAL(per_core_M * TILE_HEIGHT == in0_shard_shape[0], "Error"); + TT_FATAL(in0_shard_shape[1] == program_config.in0_block_w * in0_tile_shape[1], "Error"); + TT_FATAL(per_core_M * in0_tile_shape[0] == in0_shard_shape[0], "Error"); if (input_tensor_b.is_sharded()) { TT_FATAL( @@ -1229,7 
+1256,7 @@ void Matmul::validate( TT_FATAL(input_tensor_b.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); auto in1_shard_shape = input_tensor_b.shard_spec().value().shape; TT_FATAL(in1_shard_shape[1] == input_tensor_b.get_legacy_shape()[-1], "Error"); - TT_FATAL(per_core_N * TILE_HEIGHT == in1_shard_shape[1], "Error"); + TT_FATAL(per_core_N * in1_tile_shape[1] == in1_shard_shape[1], "Error"); TT_FATAL(in1_shard_shape[0] % K == 0, "Error"); } if (this->output_mem_config.is_sharded()) { @@ -1246,7 +1273,7 @@ void Matmul::validate( std::is_same_v || std::is_same_v) { TT_FATAL( - (input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH) % program_config.in0_block_w == 0, + (input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Kt must be divisible by in0_block_w"); TT_FATAL( program_config.per_core_M % program_config.out_subblock_h == 0, @@ -1287,6 +1314,20 @@ std::vector Matmul::compute_output_shapes(const std:: std::vector Matmul::create_output_tensors(const std::vector& input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); + auto in0_tile_shape = input_tensor_a.get_tile().get_tile_shape(); + auto in1_tile_shape = input_tensor_b.get_tile().get_tile_shape(); + if (this->output_tile.has_value()) { + TT_FATAL(this->output_tile->get_tile_shape()[1] % in1_tile_shape[1] == 0, "the override output tile width must be a multiple of the in1 tile width"); + TT_FATAL(this->output_tile->get_tile_shape()[0] == in0_tile_shape[0], "the override output tile height must equal the in0 tile height"); + if (this->output_tile->get_tile_shape()[1] != in1_tile_shape[1]) { + TT_FATAL(this->output_tile->get_tile_shape()[0] <= constants::FACE_HEIGHT, "the override output tile height must be less than or equal to the face height"); + } + if (!this->output_mem_config.is_sharded()) { + TT_FATAL(this->output_tile->get_tile_shape()[1] == in1_tile_shape[1], "the override output tile width must equal the in1 tile width"); + } + } + auto output_tile = this->output_tile.value_or(tt::tt_metal::Tile({in0_tile_shape[0], in1_tile_shape[1]})); + auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1]; auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE; TT_FATAL(this->output_dtype.has_value(), "Error"); if (this->output_mem_config.is_sharded()) { @@ -1297,12 +1338,13 @@ std::vector Matmul::create_output_tensors(const std::vector& inp if constexpr (std::is_same_v) { uint32_t M = (program_config.fuse_batch ?
input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] - : input_tensor_a.get_legacy_shape()[-2]) / - TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + : input_tensor_a.get_legacy_shape()[-2]) / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; + TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width"); + uint32_t num_blocks_y = (M - 1) / per_core_M + 1; uint32_t num_blocks_x = (N - 1) / per_core_N + 1; uint32_t num_blocks_total = num_blocks_y * num_blocks_x; @@ -1310,7 +1352,7 @@ std::vector Matmul::create_output_tensors(const std::vector& inp CoreRangeSet all_cores = num_cores_to_corerange_set(num_cores, program_config.compute_with_storage_grid_size, true); ShardSpec shard_spec = ShardSpec{ - all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, ShardOrientation::ROW_MAJOR}; + all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, ShardOrientation::ROW_MAJOR}; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; return {create_device_tensor( @@ -1318,20 +1360,23 @@ std::vector Matmul::create_output_tensors(const std::vector& inp this->output_dtype.value(), output_layout, input_tensor_a.device(), - mem_config)}; + mem_config, + output_tile)}; } else if constexpr (std::is_same_v< ProgramConfigType, MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) { uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; auto input_tensor_b_shape = input_tensor_b.get_legacy_shape(); uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; + TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width"); + CoreRangeSet all_cores = input_tensor_a.shard_spec().value().grid; ShardSpec shard_spec = ShardSpec{ - all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, ShardOrientation::ROW_MAJOR}; + all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, ShardOrientation::ROW_MAJOR}; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; return {create_device_tensor( @@ -1339,13 +1384,16 @@ std::vector Matmul::create_output_tensors(const std::vector& inp this->output_dtype.value(), output_layout, input_tensor_a.device(), - mem_config)}; + mem_config, + output_tile)}; } else if constexpr (std::is_same_v) { - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; + TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width"); + uint32_t num_blocks_y = (M - 1) / per_core_M + 1; uint32_t num_blocks_x = (N - 1) / per_core_N + 1; uint32_t num_blocks_total = num_blocks_y * num_blocks_x; @@ -1360,7 +1408,7 @@ std::vector Matmul::create_output_tensors(const std::vector& inp shard_orientation = ShardOrientation::ROW_MAJOR; 
} ShardSpec shard_spec = - ShardSpec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, shard_orientation}; + ShardSpec{all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, shard_orientation}; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; return {create_device_tensor( @@ -1368,13 +1416,16 @@ std::vector Matmul::create_output_tensors(const std::vector& inp this->output_dtype.value(), output_layout, input_tensor_a.device(), - mem_config)}; + mem_config, + output_tile)}; } else if constexpr (std::is_same_v) { - uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / TILE_HEIGHT; - uint32_t N = input_tensor_b.get_legacy_shape()[-1] / TILE_WIDTH; + uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0]; + uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1]; uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; + TT_FATAL(per_core_N % tile_width_ratio == 0, "per_core_N must be divisible by override output tile width"); + uint32_t num_blocks_y = (M - 1) / per_core_M + 1; uint32_t num_blocks_x = (N - 1) / per_core_N + 1; uint32_t num_blocks_total = num_blocks_y * num_blocks_x; @@ -1391,7 +1442,7 @@ std::vector Matmul::create_output_tensors(const std::vector& inp program_config.compute_with_storage_grid_size, shard_orientation == ShardOrientation::ROW_MAJOR); ShardSpec shard_spec = - ShardSpec{all_cores, {per_core_M * TILE_HEIGHT, per_core_N * TILE_WIDTH}, shard_orientation}; + ShardSpec{all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, shard_orientation}; auto mem_config = this->output_mem_config; mem_config.shard_spec = shard_spec; return {create_device_tensor( @@ -1399,16 +1450,26 @@ std::vector Matmul::create_output_tensors(const std::vector& inp this->output_dtype.value(), output_layout, input_tensor_a.device(), - mem_config)}; + mem_config, + output_tile)}; } else { + TT_FATAL(in0_tile_shape[0] == TILE_HEIGHT and in0_tile_shape[1] == TILE_WIDTH, + "matmul with non-optimized program config does not support tiny tile"); + TT_FATAL(in1_tile_shape[0] == TILE_HEIGHT and in1_tile_shape[1] == TILE_WIDTH, + "matmul with non-optimized program config does not support tiny tile"); + if (this->output_tile.has_value()) { + TT_FATAL(this->output_tile->get_tile_shape()[0] == TILE_HEIGHT and this->output_tile->get_tile_shape()[1] == TILE_WIDTH, + "matmul with non-optimized program config does not support tiny tile"); + } TT_THROW("Unsupported op for output sharding"); return {}; } }, chosen_program_config); } + return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype.value(), Layout::TILE, this->output_mem_config); + *this, input_tensors, this->output_dtype.value(), Layout::TILE, this->output_mem_config, output_tile); } operation::ProgramWithCallbacks Matmul::create_program( diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index 40fd00a02d5..b08678d871e 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -164,6 +164,7 @@ struct Matmul { const bool user_run_batched = false; const bool transpose_a = false; const bool transpose_b = false; + const std::optional output_tile; void validate( const std::vector &input_tensors, diff --git 
a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp index 9508fa07358..368b15edec0 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_1d_program_factory.cpp @@ -46,6 +46,10 @@ operation::ProgramWithCallbacks create_program_mcast_in0( tt_metal::Buffer* in1_buffer, tt_metal::Buffer* bias_buffer, tt_metal::Buffer* out_buffer, + const tt::tt_metal::Tile& in0_tile, + const tt::tt_metal::Tile& in1_tile, + const tt::tt_metal::Tile& bias_tile, + const tt::tt_metal::Tile& output_tile, tt::DataFormat in0_data_format, tt::DataFormat in1_data_format, tt::DataFormat bias_data_format, @@ -66,11 +70,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0( ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : (fp32_dest_acc_en ? tt::DataFormat::Float32 : output_data_format); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); - uint32_t bias_single_tile_size = tt_metal::detail::TileSize(bias_data_format); - uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); - uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); + uint32_t bias_single_tile_size = bias_tile.get_tile_size(bias_data_format); + uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); + uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); uint32_t in0_block_tiles = per_core_M * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; @@ -83,8 +87,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t in0_shard_width_in_tiles = 0; uint32_t in0_shard_height_in_tiles = 0; if (in0_is_sharded) { - in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / TILE_WIDTH; - in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / TILE_HEIGHT; + in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1]; + in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0]; in2_block_tiles = per_core_M * in0_shard_width_in_tiles; } uint32_t in2_CB_tiles = in2_block_tiles; @@ -490,7 +494,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}}) - .set_page_size(src0_cb_index, in0_single_tile_size); + .set_page_size(src0_cb_index, in0_single_tile_size) + .set_tile_dims(src0_cb_index, in0_tile); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, src0_cb_config); log_debug( LogOp, @@ -503,7 +508,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t src1_cb_index = 1; tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}}) - .set_page_size(src1_cb_index, in1_single_tile_size); + .set_page_size(src1_cb_index, in1_single_tile_size) + .set_tile_dims(src1_cb_index, in1_tile); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, 
src1_cb_config); log_debug( LogOp, @@ -519,7 +525,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( tt_metal::CircularBufferConfig src2_cb_config = tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) - .set_globally_allocated_address(*in0_buffer); + .set_globally_allocated_address(*in0_buffer) + .set_tile_dims(src2_cb_index, in0_tile); cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, src2_cb_config); log_debug( LogOp, @@ -549,13 +556,15 @@ operation::ProgramWithCallbacks create_program_mcast_in0( {output_cb_index, output_data_format}, }; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) - .set_page_size(output_cb_index, output_single_tile_size); + .set_page_size(output_cb_index, output_single_tile_size) + .set_tile_dims(output_cb_index, output_tile); // interm0 std::map interm0_cb_data_format_spec{ {interm0_cb_index, interm0_data_format}, }; interm0_cb_config = tt_metal::CircularBufferConfig(interm0_CB_size, interm0_cb_data_format_spec) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(interm0_cb_index, output_tile); auto cb_interm0 = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), interm0_cb_config); log_debug( @@ -571,7 +580,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0( {output_cb_index, output_data_format}, {interm0_cb_index, interm0_data_format}}; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, output_single_tile_size) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(output_cb_index, output_tile) + .set_tile_dims(interm0_cb_index, output_tile); } if (output_is_sharded) { @@ -590,7 +601,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0( uint32_t src3_cb_index = 3; tt_metal::CircularBufferConfig cb_src3_config = tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}}) - .set_page_size(src3_cb_index, bias_single_tile_size); + .set_page_size(src3_cb_index, bias_single_tile_size) + .set_tile_dims(src3_cb_index, bias_tile); auto cb_src3 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src3_config); log_debug( LogOp, @@ -848,6 +860,10 @@ operation::ProgramWithCallbacks create_program_mcast_in1( tt_metal::Buffer* in1_buffer, tt_metal::Buffer* bias_buffer, tt_metal::Buffer* out_buffer, + const tt::tt_metal::Tile& in0_tile, + const tt::tt_metal::Tile& in1_tile, + const tt::tt_metal::Tile& bias_tile, + const tt::tt_metal::Tile& output_tile, tt::DataFormat in0_data_format, tt::DataFormat in1_data_format, tt::DataFormat bias_data_format, @@ -871,11 +887,11 @@ operation::ProgramWithCallbacks create_program_mcast_in1( ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : (fp32_dest_acc_en ? 
tt::DataFormat::Float32 : output_data_format); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); - uint32_t bias_single_tile_size = tt_metal::detail::TileSize(bias_data_format); - uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); - uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); + uint32_t bias_single_tile_size = bias_tile.get_tile_size(bias_data_format); + uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); + uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); uint32_t in0_block_tiles = per_core_M * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; @@ -890,8 +906,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t in0_shard_height_in_tiles = 0; uint32_t in0_shard_width_in_tiles = 0; if (in0_is_sharded) { - in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / TILE_HEIGHT; - in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / TILE_WIDTH; + in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0]; + in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1]; // NOTE: Criteria for extract_shard_sub_blocks is different from mcast in0 // In the reader kernel, always need to copy to cb0 even for height=1 shards since we may not always do mcast // In mcast in0 sharded reader kernel, this is handled by mcast with loopback src @@ -1200,7 +1216,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}}) - .set_page_size(src0_cb_index, in0_single_tile_size); + .set_page_size(src0_cb_index, in0_single_tile_size) + .set_tile_dims(src0_cb_index, in0_tile); if (in0_is_sharded and not extract_shard_sub_blocks) { src0_cb_config = src0_cb_config.set_globally_allocated_address(*in0_buffer); } @@ -1219,7 +1236,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( tt_metal::CircularBufferConfig src2_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) - .set_globally_allocated_address(*in0_buffer); + .set_globally_allocated_address(*in0_buffer) + .set_tile_dims(src2_cb_index, in0_tile); cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, src2_cb_config); log_debug( LogOp, @@ -1233,7 +1251,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t src1_cb_index = 1; tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}}) - .set_page_size(src1_cb_index, in1_single_tile_size); + .set_page_size(src1_cb_index, in1_single_tile_size) + .set_tile_dims(src1_cb_index, in1_tile); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, src1_cb_config); log_debug( LogOp, @@ -1256,13 +1275,15 @@ operation::ProgramWithCallbacks create_program_mcast_in1( {output_cb_index, output_data_format}, }; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) - .set_page_size(output_cb_index, output_single_tile_size); + 
.set_page_size(output_cb_index, output_single_tile_size) + .set_tile_dims(output_cb_index, output_tile); // interm0 std::map interm0_cb_data_format_spec{ {interm0_cb_index, interm0_data_format}, }; interm0_cb_config = tt_metal::CircularBufferConfig(interm0_CB_size, interm0_cb_data_format_spec) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(interm0_cb_index, output_tile); auto cb_interm0 = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), interm0_cb_config); log_debug( @@ -1278,7 +1299,9 @@ operation::ProgramWithCallbacks create_program_mcast_in1( {output_cb_index, output_data_format}, {interm0_cb_index, interm0_data_format}}; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, output_single_tile_size) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(output_cb_index, output_tile) + .set_tile_dims(interm0_cb_index, output_tile); } if (output_is_sharded) { @@ -1297,7 +1320,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1( uint32_t src3_cb_index = 3; tt_metal::CircularBufferConfig cb_src3_config = tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}}) - .set_page_size(src3_cb_index, bias_single_tile_size); + .set_page_size(src3_cb_index, bias_single_tile_size) + .set_tile_dims(src3_cb_index, bias_tile); auto cb_src3 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src3_config); log_debug( LogOp, @@ -1537,6 +1561,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( bool untilize_out, std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); + auto in0_tile = a.get_tile(); + auto in1_tile = b.get_tile(); + // cannot use the output tensor tile directly as that might be changed by user override + auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); + auto in0_tile_shape = a.get_tile().get_tile_shape(); + auto in1_tile_shape = b.get_tile().get_tile_shape(); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1558,8 +1588,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( tt_metal::Device* device = a.device(); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); tt_metal::Buffer* in0_buffer = a.buffer(); tt_metal::Buffer* in1_buffer = b.buffer(); TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0, "Error"); @@ -1568,10 +1598,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( TT_FATAL( ashape[-1] == bshape[-2], "Dimension K (A.shape[-1] and B.shape[-2]) must match for A and B in bmm_op"); // A.K == B.K - TT_FATAL(ashape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(ashape[-1] % TILE_WIDTH == 0, "Error"); - TT_FATAL(bshape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(bshape[-1] % TILE_WIDTH == 0, "Error"); + TT_FATAL(ashape[-2] % in0_tile_shape[0] == 0, "Error"); + TT_FATAL(ashape[-1] % in0_tile_shape[1] == 0, "Error"); + TT_FATAL(bshape[-2] % in1_tile_shape[0] == 0, "Error"); + 
TT_FATAL(bshape[-1] % in1_tile_shape[1] == 0, "Error"); MathFidelity math_fidelity; bool math_approx_mode; @@ -1611,9 +1641,9 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( // NOTE: Pads matmul input dims to 512 x 512 multiples (ie. multiples of 16*32 x 16*32) // NOTE: Maximum number of tiles in output is 120 * 16^2 = 30,720 (eg. [1, 1, 5120, 6144]) uint32_t B = get_batch_size(ashape); - uint32_t Mt = ashape[-2] / TILE_HEIGHT; - uint32_t Kt = ashape[-1] / TILE_WIDTH; - uint32_t Nt = bshape[-1] / TILE_WIDTH; + uint32_t Mt = ashape[-2] / in0_tile_shape[0]; + uint32_t Kt = ashape[-1] / in0_tile_shape[1]; + uint32_t Nt = bshape[-1] / in1_tile_shape[1]; if (fuse_batch) { Mt = B * Mt; @@ -1672,6 +1702,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( in1_buffer, bias_buffer, out_buffer, + in0_tile, + in1_tile, + bias.has_value() ? bias->get_tile() : output_tile, + output_tile, in0_data_format, in1_data_format, bias_data_format, @@ -1703,6 +1737,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_( in1_buffer, bias_buffer, out_buffer, + in0_tile, + in1_tile, + bias.has_value() ? bias->get_tile() : output_tile, + output_tile, in0_data_format, in1_data_format, bias_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp index 22a5a0fd615..60f5c842a01 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_2d_program_factory.cpp @@ -46,6 +46,10 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( tt_metal::Buffer* in1_buffer, tt_metal::Buffer* bias_buffer, tt_metal::Buffer* out_buffer, + const tt::tt_metal::Tile& in0_tile, + const tt::tt_metal::Tile& in1_tile, + const tt::tt_metal::Tile& bias_tile, + const tt::tt_metal::Tile& output_tile, tt::DataFormat in0_data_format, tt::DataFormat in1_data_format, tt::DataFormat bias_data_format, @@ -69,11 +73,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : (fp32_dest_acc_en ? 
tt::DataFormat::Float32 : output_data_format); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); - uint32_t bias_single_tile_size = tt_metal::detail::TileSize(bias_data_format); - uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); - uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); + uint32_t bias_single_tile_size = bias_tile.get_tile_size(bias_data_format); + uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); + uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); const bool in0_block_sharded = in0_memory_layout == TensorMemoryLayout::BLOCK_SHARDED; const bool in0_height_sharded = in0_memory_layout == TensorMemoryLayout::HEIGHT_SHARDED; @@ -103,8 +107,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t in0_shard_width_in_tiles = 0; uint32_t in0_shard_height_in_tiles = 0; if (in0_is_sharded) { - in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / TILE_WIDTH; - in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / TILE_HEIGHT; + in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1]; + in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0]; in2_block_tiles = per_core_M * in0_shard_width_in_tiles; } @@ -661,7 +665,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}}) - .set_page_size(src0_cb_index, in0_single_tile_size); + .set_page_size(src0_cb_index, in0_single_tile_size) + .set_tile_dims(src0_cb_index, in0_tile); if (in0_height_sharded) { src0_cb_config.set_globally_allocated_address(*in0_buffer); } @@ -677,7 +682,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t src1_cb_index = 1; tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}}) - .set_page_size(src1_cb_index, in1_single_tile_size); + .set_page_size(src1_cb_index, in1_single_tile_size) + .set_tile_dims(src1_cb_index, in1_tile); if (in1_is_sharded and not in1_is_dram) { src1_cb_config.set_globally_allocated_address(*in1_buffer); } @@ -696,7 +702,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( tt_metal::CircularBufferConfig src2_cb_config = tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) - .set_globally_allocated_address(*in0_buffer); + .set_globally_allocated_address(*in0_buffer) + .set_tile_dims(src2_cb_index, in0_tile); cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores, src2_cb_config); log_debug( LogOp, @@ -726,13 +733,15 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( {output_cb_index, output_data_format}, }; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) - .set_page_size(output_cb_index, output_single_tile_size); + .set_page_size(output_cb_index, output_single_tile_size) + .set_tile_dims(output_cb_index, output_tile); // interm0 std::map interm0_cb_data_format_spec{ {interm0_cb_index, 
interm0_data_format}, }; interm0_cb_config = tt_metal::CircularBufferConfig(interm0_CB_size, interm0_cb_data_format_spec) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(interm0_cb_index, output_tile); auto cb_interm0 = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), interm0_cb_config); log_debug( @@ -748,7 +757,9 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( {output_cb_index, output_data_format}, {interm0_cb_index, interm0_data_format}}; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, output_single_tile_size) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(output_cb_index, output_tile) + .set_tile_dims(interm0_cb_index, output_tile); } if (output_is_sharded) { @@ -768,7 +779,8 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1( uint32_t src3_cb_index = 3; tt_metal::CircularBufferConfig cb_src3_config = tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}}) - .set_page_size(src3_cb_index, bias_single_tile_size); + .set_page_size(src3_cb_index, bias_single_tile_size) + .set_tile_dims(src3_cb_index, bias_tile); auto cb_src3 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src3_config); log_debug( LogOp, @@ -1238,6 +1250,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( bool untilize_out, std::optional &fused_op_signaler) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); + auto in0_tile = a.get_tile(); + auto in1_tile = b.get_tile(); + // cannot use the output tensor tile directly as that might be changed by user override + auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); + auto in0_tile_shape = a.get_tile().get_tile_shape(); + auto in1_tile_shape = b.get_tile().get_tile_shape(); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1259,8 +1277,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( tt_metal::Device* device = a.device(); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); tt_metal::Buffer* in0_buffer = a.buffer(); tt_metal::Buffer* in1_buffer = b.buffer(); TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0, "Error"); @@ -1269,10 +1287,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( TT_FATAL( ashape[-1] == bshape[-2], "Dimension K (A.shape[-1] and B.shape[-2]) must match for A and B in bmm_op"); // A.K == B.K - TT_FATAL(ashape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(ashape[-1] % TILE_WIDTH == 0, "Error"); - TT_FATAL(bshape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(bshape[-1] % TILE_WIDTH == 0, "Error"); + TT_FATAL(ashape[-2] % in0_tile_shape[0] == 0, "Error"); + TT_FATAL(ashape[-1] % in0_tile_shape[1] == 0, "Error"); + TT_FATAL(bshape[-2] % in1_tile_shape[0] == 0, "Error"); + TT_FATAL(bshape[-1] % in1_tile_shape[1] == 0, "Error"); MathFidelity math_fidelity; bool math_approx_mode; @@ -1312,9 +1330,9 @@ operation::ProgramWithCallbacks 
matmul_multi_core_reuse_mcast_2d_optimized_( // NOTE: Pads matmul input dims to 512 x 512 multiples (ie. multiples of 16*32 x 16*32) // NOTE: Maximum number of tiles in output is 120 * 16^2 = 30,720 (eg. [1, 1, 5120, 6144]) uint32_t B = get_batch_size(ashape); - uint32_t Mt = ashape[-2] / TILE_HEIGHT; - uint32_t Kt = ashape[-1] / TILE_WIDTH; - uint32_t Nt = bshape[-1] / TILE_WIDTH; + uint32_t Mt = ashape[-2] / in0_tile_shape[0]; + uint32_t Kt = ashape[-1] / in0_tile_shape[1]; + uint32_t Nt = bshape[-1] / in1_tile_shape[1]; if (fuse_batch) { Mt = B * Mt; @@ -1378,6 +1396,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_( in1_buffer, bias_buffer, out_buffer, + in0_tile, + in1_tile, + bias.has_value() ? bias->get_tile() : output_tile, + output_tile, in0_data_format, in1_data_format, bias_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp index 0c69e58873c..32d3ae6e6bf 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_mcast_dram_sharded_program_factory.cpp @@ -351,6 +351,10 @@ operation::ProgramWithCallbacks create_program_dram_sharded( tt_metal::Buffer* in1_buffer, tt_metal::Buffer* bias_buffer, tt_metal::Buffer* out_buffer, + const tt::tt_metal::Tile& in0_tile, + const tt::tt_metal::Tile& in1_tile, + const tt::tt_metal::Tile& bias_tile, + const tt::tt_metal::Tile& output_tile, tt::DataFormat in0_data_format, tt::DataFormat in1_data_format, tt::DataFormat bias_data_format, @@ -428,11 +432,11 @@ operation::ProgramWithCallbacks create_program_dram_sharded( : (fp32_dest_acc_en ? 
tt::DataFormat::Float32 : output_data_format); interm0_data_format = tt::DataFormat::Float16_b; - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); - uint32_t bias_single_tile_size = tt_metal::detail::TileSize(bias_data_format); - uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); - uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_data_format); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); + uint32_t bias_single_tile_size = bias_tile.get_tile_size(bias_data_format); + uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); + uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); uint32_t in0_block_tiles = per_core_M * in0_block_w; uint32_t in0_CB_tiles = in0_block_tiles; @@ -456,8 +460,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t out_reshard_CB_tiles = out_reshard_block_tiles; // No double buffer uint32_t out_reshard_CB_size = out_reshard_CB_tiles * output_single_tile_size; - uint32_t in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / TILE_WIDTH; - uint32_t in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / TILE_HEIGHT; + uint32_t in0_shard_width_in_tiles = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1]; + uint32_t in0_shard_height_in_tiles = in0_buffer->shard_spec().shape()[0] / in0_tile.get_tile_shape()[0]; uint32_t in2_block_tiles = per_core_M * in0_shard_width_in_tiles; uint32_t in2_CB_tiles = in2_block_tiles; uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size; @@ -711,7 +715,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}}) - .set_page_size(src0_cb_index, in0_single_tile_size); + .set_page_size(src0_cb_index, in0_single_tile_size) + .set_tile_dims(src0_cb_index, in0_tile); auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, src0_cb_config); log_debug( LogOp, @@ -724,7 +729,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t src1_cb_index = 1; tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}}) - .set_page_size(src1_cb_index, in1_single_tile_size); + .set_page_size(src1_cb_index, in1_single_tile_size) + .set_tile_dims(src1_cb_index, in1_tile); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, src1_cb_config); log_debug( LogOp, @@ -738,6 +744,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded( tt_metal::CircularBufferConfig src2_cb_config = tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}}) .set_page_size(src2_cb_index, in0_single_tile_size) + .set_tile_dims(src2_cb_index, in0_tile) .set_globally_allocated_address(*in0_buffer); auto cb_src2 = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, src2_cb_config); log_debug( @@ -761,13 +768,15 @@ operation::ProgramWithCallbacks create_program_dram_sharded( {output_cb_index, output_data_format}, }; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) - .set_page_size(output_cb_index, output_single_tile_size); + 
.set_page_size(output_cb_index, output_single_tile_size) + .set_tile_dims(output_cb_index, output_tile); // interm0 std::map interm0_cb_data_format_spec{ {interm0_cb_index, interm0_data_format}, }; interm0_cb_config = tt_metal::CircularBufferConfig(interm0_CB_size, interm0_cb_data_format_spec) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(interm0_cb_index, output_tile); auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, interm0_cb_config); log_debug( @@ -784,7 +793,9 @@ operation::ProgramWithCallbacks create_program_dram_sharded( {output_cb_index, output_data_format}, {interm0_cb_index, interm0_data_format}}; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, output_single_tile_size) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(output_cb_index, output_tile) + .set_tile_dims(interm0_cb_index, output_tile); } auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, output_cb_config); log_debug( @@ -802,7 +813,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( }; tt_metal::CircularBufferConfig output_reshard_cb_config = tt_metal::CircularBufferConfig(out_reshard_CB_size, output_reshard_cb_data_format_spec) - .set_page_size(output_reshard_cb_index, output_single_tile_size); + .set_page_size(output_reshard_cb_index, output_single_tile_size) + .set_tile_dims(output_reshard_cb_index, output_tile); output_reshard_cb_config = output_reshard_cb_config.set_globally_allocated_address(*out_buffer); auto cb_output_reshard = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, output_reshard_cb_config); @@ -810,7 +822,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t src3_cb_index = 3; tt_metal::CircularBufferConfig cb_src3_config = tt_metal::CircularBufferConfig(in3_CB_size, {{src3_cb_index, bias_data_format}}) - .set_page_size(src3_cb_index, bias_single_tile_size); + .set_page_size(src3_cb_index, bias_single_tile_size) + .set_tile_dims(src3_cb_index, bias_tile); auto cb_src3 = tt_metal::CreateCircularBuffer(program, all_cores_in_rect_grid, cb_src3_config); log_debug( LogOp, @@ -1137,6 +1150,12 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( bool skip_in0_mcast, bool skip_write_back) { const auto &ashape = a.get_legacy_shape(), bshape = b.get_legacy_shape(); + auto in0_tile = a.get_tile(); + auto in1_tile = b.get_tile(); + // cannot use the output tensor tile directly as that might be changed by user override + auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); + auto in0_tile_shape = a.get_tile().get_tile_shape(); + auto in1_tile_shape = b.get_tile().get_tile_shape(); // CB dataformats tt::DataFormat in0_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); // in0 @@ -1161,8 +1180,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( TT_FATAL(a.shard_spec().has_value() && output.shard_spec().has_value(), "Error"); CoreRangeSet all_cores_storage = a.shard_spec().value().grid; - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); + uint32_t in0_single_tile_size = a.get_tile().get_tile_size(in0_data_format); 
+ uint32_t in1_single_tile_size = b.get_tile().get_tile_size(in1_data_format); tt_metal::Buffer* in0_buffer = a.buffer(); tt_metal::Buffer* in1_buffer = b.buffer(); TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0, "Error"); @@ -1171,10 +1190,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( TT_FATAL( ashape[-1] == bshape[-2], "Dimension K (A.shape[-1] and B.shape[-2]) must match for A and B in bmm_op"); // A.K == B.K - TT_FATAL(ashape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(ashape[-1] % TILE_WIDTH == 0, "Error"); - TT_FATAL(bshape[-2] % TILE_HEIGHT == 0, "Error"); - TT_FATAL(bshape[-1] % TILE_WIDTH == 0, "Error"); + TT_FATAL(ashape[-2] % in0_tile_shape[0] == 0, "Error"); + TT_FATAL(ashape[-1] % in0_tile_shape[1] == 0, "Error"); + TT_FATAL(bshape[-2] % in1_tile_shape[0] == 0, "Error"); + TT_FATAL(bshape[-1] % in1_tile_shape[1] == 0, "Error"); MathFidelity math_fidelity; bool math_approx_mode; @@ -1208,9 +1227,9 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( // NOTE: Pads matmul input dims to 512 x 512 multiples (ie. multiples of 16*32 x 16*32) // NOTE: Maximum number of tiles in output is 120 * 16^2 = 30,720 (eg. [1, 1, 5120, 6144]) uint32_t B = 1; - uint32_t Mt = get_batch_size(ashape) * ashape[-2] / TILE_HEIGHT; - uint32_t Kt = ashape[-1] / TILE_WIDTH; - uint32_t Nt = bshape[-1] / TILE_WIDTH; + uint32_t Mt = get_batch_size(ashape) * ashape[-2] / in0_tile_shape[0]; + uint32_t Kt = ashape[-1] / in0_tile_shape[1]; + uint32_t Nt = bshape[-1] / in1_tile_shape[1]; TT_FATAL(Kt % in0_block_w == 0, "Error"); @@ -1242,6 +1261,10 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_( in1_buffer, bias_buffer, out_buffer, + in0_tile, + in1_tile, + bias.has_value() ? bias->get_tile() : output_tile, + output_tile, in0_data_format, in1_data_format, bias_data_format, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp index b054b059f91..636e31e9578 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_reuse_optimized_program_factory.cpp @@ -61,10 +61,14 @@ operation::ProgramWithCallbacks create_program( ? (fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b) : (fp32_dest_acc_en ? 
tt::DataFormat::Float32 : output_data_format); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); - uint32_t output_single_tile_size = tt_metal::detail::TileSize(output_data_format); - uint32_t interm0_single_tile_size = tt_metal::detail::TileSize(interm0_data_format); + auto in0_tile = in0.get_tile(); + auto in1_tile = in1.get_tile(); + // cannot use the output tensor tile directly as that might be changed by user override + auto output_tile = tt::tt_metal::Tile({in0_tile.get_tile_shape()[0], in1_tile.get_tile_shape()[1]}); + uint32_t in0_single_tile_size = in0_tile.get_tile_size(in0_data_format); + uint32_t in1_single_tile_size = in1_tile.get_tile_size(in1_data_format); + uint32_t output_single_tile_size = output_tile.get_tile_size(output_data_format); + uint32_t interm0_single_tile_size = output_tile.get_tile_size(interm0_data_format); tt_metal::Buffer* in0_buffer = in0.buffer(); tt_metal::Buffer* in1_buffer = in1.buffer(); @@ -176,13 +180,17 @@ operation::ProgramWithCallbacks create_program( program, "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0.cpp", all_cores, - ReaderDataMovementConfig(reader_compile_time_args, mm_kernel_in0_reader_defines)); + ReaderDataMovementConfig( + reader_compile_time_args, + mm_kernel_in0_reader_defines)); KernelHandle mm_kernel_in1_reader_writer_id = tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_writer_bmm_tile_layout_in1.cpp", all_cores, - WriterDataMovementConfig(reader_writer_compile_time_args, mm_kernel_in1_reader_writer_defines)); + WriterDataMovementConfig( + reader_writer_compile_time_args, + mm_kernel_in1_reader_writer_defines)); vector compute_kernel_args_group_1 = { in0_block_w, // in0_block_w @@ -261,7 +269,8 @@ operation::ProgramWithCallbacks create_program( uint32_t src0_cb_index = 0; tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(in0_CB_size, {{src0_cb_index, in0_data_format}}) - .set_page_size(src0_cb_index, in0_single_tile_size); + .set_page_size(src0_cb_index, in0_single_tile_size) + .set_tile_dims(src0_cb_index, in0_tile); if (in0_is_sharded) { cb_src0_config = cb_src0_config.set_globally_allocated_address(*in0_buffer); } @@ -270,7 +279,8 @@ operation::ProgramWithCallbacks create_program( uint32_t src1_cb_index = 1; tt_metal::CircularBufferConfig cb_src1_config = tt_metal::CircularBufferConfig(in1_CB_size, {{src1_cb_index, in1_data_format}}) - .set_page_size(src1_cb_index, in1_single_tile_size); + .set_page_size(src1_cb_index, in1_single_tile_size) + .set_tile_dims(src1_cb_index, in1_tile); if (in1_is_sharded) { cb_src1_config = cb_src1_config.set_globally_allocated_address(*in1_buffer); } @@ -289,13 +299,15 @@ operation::ProgramWithCallbacks create_program( {output_cb_index, output_data_format}, }; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) - .set_page_size(output_cb_index, output_single_tile_size); + .set_page_size(output_cb_index, output_single_tile_size) + .set_tile_dims(output_cb_index, output_tile); // interm0 std::map interm0_cb_data_format_spec{ {interm0_cb_index, interm0_data_format}, }; interm0_cb_config = tt_metal::CircularBufferConfig(interm0_CB_size, interm0_cb_data_format_spec) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(interm0_cb_index, output_tile); auto 
cb_interm0 = tt_metal::CreateCircularBuffer(program, CoreRangeSet({all_cores}), interm0_cb_config); } else { @@ -304,7 +316,9 @@ operation::ProgramWithCallbacks create_program( {output_cb_index, output_data_format}, {interm0_cb_index, interm0_data_format}}; output_cb_config = tt_metal::CircularBufferConfig(out_CB_size, output_cb_data_format_spec) .set_page_size(output_cb_index, output_single_tile_size) - .set_page_size(interm0_cb_index, interm0_single_tile_size); + .set_page_size(interm0_cb_index, interm0_single_tile_size) + .set_tile_dims(output_cb_index, output_tile) + .set_tile_dims(interm0_cb_index, output_tile); } // std::map output_cb_data_format_spec { // {output_cb_index, output_data_format}, @@ -474,6 +488,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_( bool untilize_out) { const auto& ashape = a.get_legacy_shape(); const auto& bshape = b.get_legacy_shape(); + auto in0_tile_shape = a.get_tile().get_tile_shape(); + auto in1_tile_shape = b.get_tile().get_tile_shape(); TT_FATAL( (bcast_batch == false) or (ashape[0] == 1) or (ashape.rank() == 2), @@ -486,8 +502,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_( tt_metal::Device* device = a.device(); - uint32_t in0_single_tile_size = tt_metal::detail::TileSize(in0_data_format); - uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format); tt_metal::Buffer* in0_buffer = a.buffer(); tt_metal::Buffer* in1_buffer = b.buffer(); @@ -529,9 +543,9 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_( // NOTE: Only supports matmuls where output is blocks of 16 x 16 tiles (ie. multiples of 16*32 x 16*32) // NOTE: Maximum number of tiles in output is 120 * 16^2 = 30,720 (eg. [1, 1, 5120, 6144]) uint32_t B = get_batch_size(ashape); - uint32_t Mt = ashape[-2] / TILE_HEIGHT; - uint32_t Kt = ashape[-1] / TILE_WIDTH; - uint32_t Nt = bshape[-1] / TILE_WIDTH; + uint32_t Mt = ashape[-2] / in0_tile_shape[0]; + uint32_t Kt = ashape[-1] / in0_tile_shape[1]; + uint32_t Nt = bshape[-1] / in1_tile_shape[1]; // TODO: Generalize TT_FATAL(!fuse_batch, "Only fuse_batch=false is supported for optimized bmm!"); diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp index f6cff72c332..aaab35da96b 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.cpp @@ -103,7 +103,8 @@ Tensor MatmulOperation::invoke( const std::optional program_config, const std::optional& activation, const std::optional compute_kernel_config, - const std::optional core_grid) { + const std::optional core_grid, + const std::optional& output_tile) { std::optional user_core_coord; if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); @@ -124,7 +125,8 @@ Tensor MatmulOperation::invoke( get_fused_activation(activation), user_run_batched, transpose_a, - transpose_b}, + transpose_b, + output_tile}, /*queue_id=*/0); } @@ -139,7 +141,8 @@ Tensor LinearOperation::invoke( const std::optional program_config, const std::optional& activation, const std::optional compute_kernel_config, - const std::optional core_grid) { + const std::optional core_grid, + const std::optional& output_tile) { std::optional user_core_coord; if (core_grid.has_value()) { user_core_coord = CoreCoord(core_grid->x, core_grid->y); @@ -162,7 +165,8 @@ Tensor LinearOperation::invoke( get_fused_activation(activation), /*user_run_batched=*/false, transpose_a, - transpose_b}, + transpose_b, + output_tile}, /*queue_id=*/0); } diff --git 
a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp index b13854d483d..54b120fc658 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul.hpp @@ -46,7 +46,8 @@ struct MatmulOperation { const std::optional program_config = std::nullopt, const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, - const std::optional core_grid = std::nullopt); + const std::optional core_grid = std::nullopt, + const std::optional& output_tile = std::nullopt); }; struct LinearOperation { @@ -61,7 +62,8 @@ struct LinearOperation { const std::optional program_config = std::nullopt, const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, - const std::optional core_grid = std::nullopt); + const std::optional core_grid = std::nullopt, + const std::optional& output_tile = std::nullopt); }; } // namespace matmul diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp index 93a84b5d44f..56fe86980e6 100644 --- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp @@ -246,7 +246,8 @@ void py_module(py::module& module) { const std::optional program_config, const std::optional& activation, const std::optional compute_kernel_config, - const std::optional core_grid) -> ttnn::Tensor { + const std::optional core_grid, + const std::optional& output_tile) -> ttnn::Tensor { return self( input_tensor_a, input_tensor_b, @@ -257,7 +258,8 @@ void py_module(py::module& module) { program_config, activation, compute_kernel_config, - core_grid); + core_grid, + output_tile); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -270,6 +272,7 @@ void py_module(py::module& module) { py::arg("activation") = std::nullopt, py::arg("compute_kernel_config") = std::nullopt, py::arg("core_grid") = std::nullopt, + py::arg("output_tile") = std::nullopt, }); bind_registered_operation( @@ -314,7 +317,8 @@ void py_module(py::module& module) { const std::optional program_config, const std::optional& activation, const std::optional compute_kernel_config, - const std::optional core_grid) -> ttnn::Tensor { + const std::optional core_grid, + const std::optional& output_tile) -> ttnn::Tensor { return self( input_tensor_a, input_tensor_b, @@ -326,7 +330,8 @@ void py_module(py::module& module) { program_config, activation, compute_kernel_config, - core_grid); + core_grid, + output_tile); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), @@ -340,6 +345,7 @@ void py_module(py::module& module) { py::arg("activation") = std::nullopt, py::arg("compute_kernel_config") = std::nullopt, py::arg("core_grid") = std::nullopt, + py::arg("output_tile") = std::nullopt, }); } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 8979d11fe47..0617c7db17b 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -257,6 +257,7 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) { cb_wait_front(in_cb, num_tiles); cb_reserve_back(out_cb, num_tiles); + #pragma GCC unroll 0 for (uint32_t i = 0; i < num_tiles; i++) { acquire_dst(tt::DstMode::Half); copy_tile(in_cb, i, 0/*dst*/); diff --git 
a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp index a406997bcf8..26188cc8923 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp @@ -376,7 +376,7 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) { cb_pop_front(in_cb, num_tiles); } -void cb_matmul_blocks(const uint32_t& in0_cb, const uint32_t& in1_cb, const uint32_t& out_cb, const uint32_t& M, const uint32_t& N, const uint32_t& K, const uint32_t& num_blocks, const uint32_t& in0_num_subblocks, const uint32_t& in1_num_subblocks, +ALWI void cb_matmul_blocks(const uint32_t& in0_cb, const uint32_t& in1_cb, const uint32_t& out_cb, const uint32_t& M, const uint32_t& N, const uint32_t& K, const uint32_t& num_blocks, const uint32_t& in0_num_subblocks, const uint32_t& in1_num_subblocks, const uint32_t& in0_block_w, const uint32_t& subblock_h, const uint32_t& subblock_w, const bool& transpose) { // precondition: in0_cb has M*K produced // preconditino: in1_cb has K*N produced diff --git a/ttnn/cpp/ttnn/run_operation.hpp b/ttnn/cpp/ttnn/run_operation.hpp index 7524d4b941e..3286146dfd8 100644 --- a/ttnn/cpp/ttnn/run_operation.hpp +++ b/ttnn/cpp/ttnn/run_operation.hpp @@ -23,7 +23,8 @@ auto generic_create_output_tensors( const Tensors& input_tensors, const std::optional output_dtype, const Layout output_layout, - const std::optional& output_mem_config) -> ProgramOutputTensors { + const std::optional& output_mem_config, + const std::optional& tile = std::nullopt) -> ProgramOutputTensors { const auto& input_tensor = input_tensors.at(0); const auto& output_shapes = operation.compute_output_shapes(input_tensors); @@ -36,7 +37,7 @@ auto generic_create_output_tensors( output_dtype.value_or(input_tensors.at(0).get_dtype()), output_layout, input_tensor.device(), - output_mem_config.value_or(input_tensors.at(0).memory_config()))); + output_mem_config.value_or(input_tensors.at(0).memory_config()), tile)); } return output_tensors; } diff --git a/ttnn/cpp/ttnn/run_operation_inl.hpp b/ttnn/cpp/ttnn/run_operation_inl.hpp index 5c2bec5720a..9a97af0b4af 100644 --- a/ttnn/cpp/ttnn/run_operation_inl.hpp +++ b/ttnn/cpp/ttnn/run_operation_inl.hpp @@ -232,6 +232,7 @@ void launch_op( output_tensor->tensor_attributes->shape = local_tensor->tensor_attributes->shape; output_tensor->tensor_attributes->dtype = local_tensor->tensor_attributes->dtype; output_tensor->tensor_attributes->layout = local_tensor->tensor_attributes->layout; + output_tensor->tensor_attributes->tile = local_tensor->tensor_attributes->tile; output_tensor->tensor_attributes->metadata_populated = true; } } diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 2634ca5f295..d72b0bf50f8 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -28,10 +28,20 @@ namespace tt { namespace tt_metal { -Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout) : +Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, const std::optional& tile) : tensor_id{std::nullopt}, - tensor_attributes(std::make_shared(storage, shape, dtype, layout)), deallocate_through_destructor(false) { + + if (tile.has_value()) { + tensor_attributes = std::make_shared(storage, shape, dtype, 
layout, tile.value());
+
+        if (tile->get_tile_shape()[0] != TILE_HEIGHT or tile->get_tile_shape()[1] != TILE_WIDTH) {
+            tt::log_warning("only matmul op currently supports the customized tile shape: {}", tile->get_tile_shape());
+        }
+    } else {
+        tensor_attributes = std::make_shared(storage, shape, dtype, layout);
+    }
+
     ZoneScoped;
     std::visit(
         [&](auto&& storage) {
@@ -81,8 +91,8 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L
     this->tensor_attributes->metadata_populated = true;
 }

-Tensor::Tensor(const Storage storage, const tt::tt_metal::LegacyShape shape, DataType dtype, Layout layout) :
-    Tensor(storage, ttnn::Shape{shape}, dtype, layout) {}
+Tensor::Tensor(const Storage storage, const tt::tt_metal::LegacyShape shape, DataType dtype, Layout layout, const std::optional& tile) :
+    Tensor(storage, ttnn::Shape{shape}, dtype, layout, tile) {}

 Tensor::~Tensor() {
     ZoneScoped;
@@ -260,6 +270,7 @@ void Tensor::deepcopy(const Tensor& other) {
     this->set_storage(other.get_storage());
     this->set_dtype(other.get_dtype());
     this->set_layout(other.get_layout());
+    this->set_tile(other.get_tile());
     // Set metadata populated flag for getters
     this->tensor_attributes->metadata_populated = true;
     this->tensor_attributes->num_workers_completed++;
@@ -272,6 +283,7 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) {
     this->set_shape(other.get_shape());
     this->set_dtype(other.get_dtype());
     this->set_layout(other.get_layout());
+    this->set_tile(other.get_tile());
     // Populate storage container with buffers + shapes
     std::visit(
         [this](auto&& storage) {
@@ -361,6 +373,10 @@ const Layout& Tensor::get_layout() const {
     this->wait_for_tensor_metadata_populated();
     return this->tensor_attributes->layout;
 }
+const Tile& Tensor::get_tile() const {
+    this->wait_for_tensor_metadata_populated();
+    return this->tensor_attributes->tile;
+}

 const Storage& Tensor::get_storage() const {
     this->wait_for_tensor_data_populated();
@@ -501,7 +517,7 @@ uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_
 uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }

 Tensor create_device_tensor(
-    const tt::tt_metal::LegacyShape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
+    const tt::tt_metal::LegacyShape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) {
     ZoneScoped;
     GraphTracker::instance().track_function_start("tt::tt_metal::create_device_tensor", shape, data_type, layout, device, memory_config);
     if (memory_config.is_sharded()) {
@@ -517,15 +533,15 @@ Tensor create_device_tensor(
         }

         auto element_size = tensor_impl::element_size_bytes(data_type);
-        auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape);
+        auto page_shape = tensor_impl::get_sharded_page_shape(layout, data_type, shard_spec.shape, tile);
         std::array tensor2d_size = {other_dims / page_shape[0], width / page_shape[1]};
         ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size);

         size_t packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type));
         auto device_buffer = tensor_impl::allocate_buffer_on_device(
-            packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec_buffer);
+            packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec_buffer, tile);
-        auto output =
Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); + auto output = Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -533,8 +549,8 @@ Tensor create_device_tensor( size_t packed_size_in_bytes = tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); auto device_buffer = tensor_impl::allocate_buffer_on_device( - packed_size_in_bytes, device, shape, data_type, layout, memory_config); - auto output = Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); + packed_size_in_bytes, device, shape, data_type, layout, memory_config, std::nullopt, tile); + auto output = Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -661,12 +677,12 @@ void memcpy(Tensor& dst, const Tensor& src, const std::optional tra } Tensor allocate_tensor_on_device( - const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) { + const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { // Top level wrapper to asynchronously create a device tensor (single device) Tensor device_tensor = Tensor({device}); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); - device->push_work([shape, data_type, layout, device, memory_config, device_tensor]() mutable { - auto local_tensor = create_device_tensor(shape.value, data_type, layout, device, memory_config); + device->push_work([shape, data_type, layout, device, memory_config, tile, device_tensor]() mutable { + auto local_tensor = create_device_tensor(shape.value, data_type, layout, device, memory_config, tile); device_tensor.populate_buffers_and_metadata(local_tensor); }); device_tensor.tensor_attributes->update_main_thread_ref_count(device, device_tensor_ref_count); @@ -678,7 +694,9 @@ Tensor allocate_tensor_on_device( DataType data_type, Layout layout, MeshDevice* mesh_device, - const MemoryConfig& memory_config) { + const MemoryConfig& memory_config, + const std::optional& tile + ) { // Top level wrapper to asynchronously create a device tensor (multi-device) Tensor device_tensor = Tensor(mesh_device->get_devices()); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); @@ -687,8 +705,8 @@ Tensor allocate_tensor_on_device( for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work([shape, data_type, layout, worker, memory_config, device_tensor, worker_index]() mutable { - auto local_tensor = create_device_tensor(shape.value, data_type, layout, worker, memory_config); + worker->push_work([shape, data_type, layout, worker, memory_config, tile, device_tensor, worker_index]() mutable { + auto local_tensor = create_device_tensor(shape.value, data_type, layout, worker, memory_config, tile); insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; @@ -696,6 +714,9 @@ Tensor allocate_tensor_on_device( device_tensor.set_shape(ttnn::Shape(shape)); device_tensor.set_dtype(data_type); device_tensor.set_layout(layout); + if (tile.has_value()) { + 
device_tensor.set_tile(tile.value()); + } device_tensor.tensor_attributes->metadata_populated = true; } }); @@ -726,6 +747,7 @@ void write_tensor(Tensor host_tensor, Tensor device_tensor, uint8_t cq_id) { TT_FATAL(async_safe_tensor.get_shape() == device_tensor.get_shape(), "Error"); TT_FATAL(async_safe_tensor.get_dtype() == device_tensor.get_dtype(), "Error"); TT_FATAL(async_safe_tensor.get_layout() == device_tensor.get_layout(), "Error"); + TT_FATAL(async_safe_tensor.get_tile() == device_tensor.get_tile(), "Error"); std::visit( [worker_index, worker, cq_id, &async_safe_tensor](auto&& s) { void* host_data = nullptr; diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index d053a98b51f..cc3a81cc0e5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -18,6 +18,7 @@ #include "ttnn/common/constants.hpp" #include "ttnn/tensor/types.hpp" #include "tt_metal/impl/buffers/buffer.hpp" +#include "tt_metal/impl/tile/tile.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/device/mesh_device.hpp" #include "tt_metal/tt_stl/reflection.hpp" @@ -32,6 +33,7 @@ struct Tensor { ttnn::Shape shape; DataType dtype; Layout layout; + Tile tile; uint32_t num_shards_to_be_populated = 0; uint32_t main_thread_ref_count = 0; std::atomic num_sibling_workers_sharing_tensor = 0; @@ -41,10 +43,10 @@ struct Tensor { bool deallocated = false; // Set to true if device side storage was deallocated bool dynamic_storage = false; // Storage type can change, depending on op behaviour bool track_ref_count = false; - TensorAttributes(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout) : - storage(storage), shape(shape), dtype(dtype), layout(layout) {} + TensorAttributes(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, Tile tile = std::array{32, 32}) : + storage(storage), shape(shape), dtype(dtype), layout(layout), tile(tile) {} TensorAttributes() : - shape(std::array{0xff, 0xff, 0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID) {} + shape(std::array{0xff, 0xff, 0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID), tile(std::array{32, 32}) {} ~TensorAttributes() = default; // Use these functions to manage the main_thread_ref_count for a tensor attr instance. @@ -118,8 +120,8 @@ struct Tensor { workers(std::vector{}), deallocate_through_destructor(false) {} - Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout); - Tensor(const Storage storage, const tt::tt_metal::LegacyShape shape, DataType dtype, Layout layout); + Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor(const Storage storage, const tt::tt_metal::LegacyShape shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); // Constructor to initialize unpopulated tensor with workers and storage specified. Use this when creating tensor // handles in async mode. @@ -293,6 +295,7 @@ struct Tensor { const ttnn::Shape &get_shape() const; const DataType &get_dtype() const; const Layout &get_layout() const; + const Tile &get_tile() const; // ====================================================================================== // Non-Blocking Getters. 
Query attributes directly, without waiting for worker completion @@ -302,6 +305,7 @@ struct Tensor { inline const ttnn::Shape &shape() const { return this->tensor_attributes->shape; }; inline const DataType &dtype() const { return this->tensor_attributes->dtype; }; inline const Layout &layout() const { return this->tensor_attributes->layout; }; + inline const Tile &tile() const { return this->tensor_attributes->tile; }; // ====================================================================================== // Setters @@ -310,6 +314,7 @@ struct Tensor { inline void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } inline void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } inline void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } + inline void set_tile(const Tile &tile) { this->tensor_attributes->tile = tile; } // ====================================================================================== // Extra Helper Functions // ====================================================================================== @@ -367,9 +372,9 @@ struct Tensor { // Size in bytes of a single element held in tensor uint32_t element_size() const; - static constexpr auto attribute_names = std::forward_as_tuple("storage", "shape", "dtype", "layout"); + static constexpr auto attribute_names = std::forward_as_tuple("storage", "shape", "dtype", "layout", "tile"); const auto attribute_values() const { - return std::forward_as_tuple(this->tensor_attributes->storage, this->tensor_attributes->shape, this->tensor_attributes->dtype, this->tensor_attributes->layout); + return std::forward_as_tuple(this->tensor_attributes->storage, this->tensor_attributes->shape, this->tensor_attributes->dtype, this->tensor_attributes->layout, this->tensor_attributes->tile); } std::vector host_page_ordering(); @@ -396,15 +401,17 @@ Tensor create_device_tensor( DataType dtype, Layout layout, Device *device, - const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + const std::optional& tile = std::nullopt); static Tensor create_device_tensor( const ttnn::Shape &shape, DataType dtype, Layout layout, Device *device, - const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { - return create_device_tensor(shape.value, dtype, layout, device, memory_config); + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + const std::optional& tile = std::nullopt) { + return create_device_tensor(shape.value, dtype, layout, device, memory_config, tile); } // template @@ -432,13 +439,15 @@ Tensor allocate_tensor_on_device( DataType data_type, Layout layout, Device *device, - const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + const std::optional& tile = std::nullopt); Tensor allocate_tensor_on_device( const ttnn::Shape &shape, DataType data_type, Layout layout, MeshDevice *mesh_device, - const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}); + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + const std::optional& tile = std::nullopt); void write_tensor(Tensor 
host_tensor, Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId); // Maps a tensor to the set of devices in the device-mesh that the shards will be distributed across. diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index ea4c0fc753a..5dc5560a5d6 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -43,9 +43,12 @@ uint32_t element_size_bytes(DataType dtype) { } } -uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape) { +uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape, const std::optional& tile) { uint32_t W = shape[-1]; uint32_t page_size = 0; + const auto tile_HW = tile.has_value() ? tile->get_tile_hw() : constants::TILE_HW; + const auto bfloat8b_tile_HW = tile.has_value() ? tile_HW + 64 : constants::BFLOAT8_B_TILE_HW; + const auto bfloat4b_tile_HW = tile.has_value() ? tile_HW / 2 + 64 : constants::BFLOAT4_B_TILE_HW; switch (layout) { case Layout::ROW_MAJOR: { uint32_t size_of_element = element_size_bytes(dtype); @@ -57,24 +60,24 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, case DataType::BFLOAT16: { // Float is converted to bfloat16 before being written to device uint32_t size_of_element = element_size_bytes(DataType::BFLOAT16); - page_size = constants::TILE_HW * size_of_element; + page_size = tile_HW * size_of_element; } break; case DataType::FLOAT32: { uint32_t size_of_element = element_size_bytes(DataType::FLOAT32); - page_size = constants::TILE_HW * size_of_element; + page_size = tile_HW * size_of_element; } break; case DataType::UINT32: case DataType::INT32: case DataType::UINT16: case DataType::UINT8:{ uint32_t size_of_element = element_size_bytes(dtype); - page_size = constants::TILE_HW * size_of_element; + page_size = tile_HW * size_of_element; } break; case DataType::BFLOAT4_B: { - page_size = constants::BFLOAT4_B_TILE_HW; + page_size = bfloat4b_tile_HW; } break; case DataType::BFLOAT8_B: { - page_size = constants::BFLOAT8_B_TILE_HW; + page_size = bfloat8b_tile_HW; } break; default: TT_ASSERT(false && "Unsupported data type!"); } @@ -86,13 +89,16 @@ uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, return page_size; } -std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape) { +std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape, const std::optional& tile) { // Physical limitation in FD for now switch (layout) { case Layout::ROW_MAJOR: // TODO: Explore valid page shapes other than 1,W return {1, shard_shape[1]}; - case Layout::TILE: return {constants::TILE_HEIGHT, constants::TILE_WIDTH}; + case Layout::TILE: { + auto tile_shape = tile.value_or(Tile{{constants::TILE_HEIGHT, constants::TILE_WIDTH}}).get_tile_shape(); + return {tile_shape[0], tile_shape[1]}; + } default: TT_THROW("Unsupported layout to write to device"); } } @@ -102,7 +108,8 @@ void validate_sharded_buffer_allocation( Layout layout, DataType data_type, const ShardSpecBuffer& shard_params, - const MemoryConfig& memory_config) { + const MemoryConfig& memory_config, + const std::optional& tile) { const auto& shard_spec = memory_config.shard_spec.value(); const auto& shard_shape = shard_spec.shape; @@ -163,8 +170,9 @@ void validate_sharded_buffer_allocation( TT_THROW("Unsupported sharding scheme"); } if (layout == Layout::TILE) { + auto tile_shape = 
tile.value_or(Tile{{constants::TILE_HEIGHT, constants::TILE_WIDTH}}).get_tile_shape(); TT_FATAL( - (shard_shape[0] % constants::TILE_HEIGHT == 0 && shard_shape[1] % constants::TILE_WIDTH == 0), + (shard_shape[0] % tile_shape[0] == 0 && shard_shape[1] % tile_shape[1] == 0), "Shard shape must be tile sized"); } else if (layout == Layout::ROW_MAJOR) { TT_FATAL(shard_shape[1] * tensor_impl::element_size_bytes(data_type) % sizeof(uint32_t) == 0, "Error"); @@ -179,8 +187,9 @@ DeviceBuffer allocate_interleaved_buffer_on_device( const tt::tt_metal::LegacyShape& shape, DataType data_type, Layout layout, - const MemoryConfig& memory_config) { - uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, shape); + const MemoryConfig& memory_config, + const std::optional& tile) { + uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, shape, tile); return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); } @@ -196,10 +205,11 @@ DeviceBuffer allocate_sharded_buffer_on_device( DataType data_type, Layout layout, const ShardSpecBuffer& shard_params, - const MemoryConfig& memory_config) { - validate_sharded_buffer_allocation(shape, layout, data_type, shard_params, memory_config); + const MemoryConfig& memory_config, + const std::optional& tile) { + validate_sharded_buffer_allocation(shape, layout, data_type, shard_params, memory_config, tile); const auto& page_shape = shard_params.page_shape; - uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, page_shape); + uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, page_shape, tile); return std::make_shared( device, buffer_size_bytes, page_size, memory_config.buffer_type, memory_config.memory_layout, shard_params); @@ -214,16 +224,17 @@ DeviceBuffer allocate_buffer_on_device( DataType data_type, Layout layout, const MemoryConfig& memory_config, - const std::optional& shard_spec) { + const std::optional& shard_spec, + const std::optional& tile) { if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { return detail::allocate_interleaved_buffer_on_device( - buffer_size_bytes, device, shape, data_type, layout, memory_config); + buffer_size_bytes, device, shape, data_type, layout, memory_config, tile); } else if (memory_config.memory_layout == tt::tt_metal::TensorMemoryLayout::SINGLE_BANK) { return detail::allocate_contiguous_buffer_on_device(buffer_size_bytes, device, memory_config); } else { TT_ASSERT(memory_config.is_sharded(), "Incorrect Memory Layout"); return detail::allocate_sharded_buffer_on_device( - buffer_size_bytes, device, shape, data_type, layout, shard_spec.value(), memory_config); + buffer_size_bytes, device, shape, data_type, layout, shard_spec.value(), memory_config, tile); } } @@ -284,7 +295,7 @@ Tensor pad_bfloat8_b( unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT8_B @@ -296,7 +307,8 @@ Tensor pad_bfloat8_b( std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, - tensor.get_layout()); + 
tensor.get_layout(), + tensor.get_tile()); } Tensor unpad_bfloat8_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end) { @@ -308,7 +320,7 @@ Tensor unpad_bfloat8_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& ou unpack_bfp8_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT8_B @@ -320,7 +332,8 @@ Tensor unpad_bfloat8_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& ou std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT8_B, - tensor.get_layout()); + tensor.get_layout(), + tensor.get_tile()); } Tensor pad_bfloat4_b( @@ -333,7 +346,7 @@ Tensor pad_bfloat4_b( unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) .pad(output_tensor_shape, input_tensor_start, pad_value); // Convert back to BFLOAT4_B @@ -345,7 +358,8 @@ Tensor pad_bfloat4_b( std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, - tensor.get_layout()); + tensor.get_layout(), + tensor.get_tile()); } Tensor unpad_bfloat4_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end) { @@ -357,7 +371,7 @@ Tensor unpad_bfloat4_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& ou unpack_bfp4_tiles_into_float_vec(input_packed_data, /*row_major_output=*/false, /*is_exp_a=*/false); auto input_float_buffer = owned_buffer::create(std::move(input_float_data)); auto float_tensor = - Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout()) + Tensor(OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, tensor.get_layout(), tensor.get_tile()) .unpad(output_tensor_start, output_tensor_end); // Convert back to BFLOAT4_B @@ -369,7 +383,8 @@ Tensor unpad_bfloat4_b(const Tensor& tensor, const tt::tt_metal::LegacyShape& ou std::move(OwnedStorage{std::move(output_uint32_buffer)}), float_tensor.get_legacy_shape(), DataType::BFLOAT4_B, - tensor.get_layout()); + tensor.get_layout(), + tensor.get_tile()); } // ====================================================================================== @@ -574,7 +589,8 @@ std::string to_string(const Tensor& tensor, std::optional original_dty OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, - tensor.get_layout()); + tensor.get_layout(), + tensor.get_tile()); return to_string(float_tensor, tensor.get_dtype()); } @@ -588,7 +604,8 @@ std::string to_string(const Tensor& tensor, std::optional original_dty OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, - tensor.get_layout()); + 
tensor.get_layout(), + tensor.get_tile()); return to_string(float_tensor, tensor.get_dtype()); } const auto buffer = owned_buffer::get_as(storage.buffer); @@ -656,7 +673,7 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id read_data_from_device_buffer(device_buffer, data_vec); } auto output_buffer = owned_buffer::create(std::move(data_vec)); - return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout()); + return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); } template @@ -673,6 +690,7 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { host_tensor.set_shape(tensor.get_shape()); host_tensor.set_dtype(tensor.get_dtype()); host_tensor.set_layout(tensor.get_layout()); + host_tensor.set_tile(tensor.get_tile()); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); } return host_tensor; @@ -717,7 +735,7 @@ Tensor to_host_sharded(const Tensor& tensor) { ::detail::ReadFromBuffer(*device_buffer, device_data, true); auto data_vec = unpack_uint32_vec(device_data); auto output_buffer = owned_buffer::create(std::move(data_vec)); - return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout()); + return Tensor(OwnedStorage{output_buffer}, tensor.get_legacy_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); } template Tensor to_host_sharded(const Tensor& tensor); @@ -786,13 +804,14 @@ DeviceBuffer initialize_data_on_device( Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec, + const std::optional& tile, std::optional> queue = std::nullopt) { ZoneScoped; TT_ASSERT(device != nullptr); auto packed_size_in_bytes = packed_buffer_size_bytes(data_to_write.size()); auto device_buffer = - allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec); + allocate_buffer_on_device(packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec, tile); const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { write_data_to_device_buffer( @@ -812,9 +831,10 @@ DeviceBuffer to_device_buffer( Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec, + const std::optional& tile, std::optional> queue) { return std::visit( - [&device, &shape, &data_type, &layout, memory_config, shard_spec](auto&& storage) -> DeviceBuffer { + [&device, &shape, &data_type, &layout, &tile, memory_config, shard_spec](auto&& storage) -> DeviceBuffer { using StorageType = std::decay_t; if (memory_config.is_sharded()) { TT_ASSERT(shard_spec.has_value(), "If sharded must provide shard_spec"); @@ -827,12 +847,13 @@ DeviceBuffer to_device_buffer( compute_buffer_size(shape, data_type), data_to_write.size()); if (layout == Layout::TILE) { + auto tile_shape = tile.value_or(Tile{{constants::TILE_HEIGHT, constants::TILE_WIDTH}}).get_tile_shape(); TT_ASSERT( - (shape[-2] % tt::constants::TILE_HEIGHT == 0 && shape[-1] % tt::constants::TILE_WIDTH == 0), + (shape[-2] % tile_shape[0] == 0 && shape[-1] % tile_shape[1] == 0), "Tensor shape incompatible for specified layout"); } return initialize_data_on_device( - data_to_write, device, shape, data_type, layout, memory_config, shard_spec); + data_to_write, device, shape, data_type, layout, memory_config, shard_spec, tile); } else if constexpr 
(std::is_same_v) { TT_THROW("Device storage doesn't support to_device_buffer"); } else if constexpr (std::is_same_v) { @@ -866,10 +887,11 @@ Tensor to_device( auto shape = tensor.get_legacy_shape(); auto data_type = tensor.get_dtype(); auto layout = tensor.get_layout(); + auto tile = tensor.get_tile(); std::optional shard_spec_buffer_opt = std::nullopt; if (memory_config.is_sharded()) { - auto page_shape = get_sharded_page_shape(layout, data_type, memory_config.shard_spec.value().shape); + auto page_shape = get_sharded_page_shape(layout, data_type, memory_config.shard_spec.value().shape, tile); auto width = shape[-1]; auto other_dims = 1; @@ -882,8 +904,8 @@ Tensor to_device( } auto device_buffer = tensor_impl::to_device_buffer( - tensor.get_storage(), target_device, shape, data_type, layout, memory_config, shard_spec_buffer_opt, queue); - return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout); + tensor.get_storage(), target_device, shape, data_type, layout, memory_config, shard_spec_buffer_opt, tile, queue); + return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); } template Tensor to_device( @@ -947,18 +969,19 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { auto shape = tensor.get_legacy_shape(); auto source_layout = tensor.get_layout(); - auto convert = [&shape, source_layout, target_layout](const auto& input_data) -> std::vector { + auto tile = tensor.tile(); + auto convert = [tile, &shape, source_layout, target_layout](const auto& input_data) -> std::vector { switch (source_layout) { case Layout::ROW_MAJOR: if (target_layout == Layout::TILE) { - return convert_layout_row_major_to_tile(shape, input_data); + return convert_layout_row_major_to_tile(shape, tile, input_data); } else { TT_THROW("Unsupported layout conversion"); } break; case Layout::TILE: if (target_layout == Layout::ROW_MAJOR) { - return convert_layout_tile_to_row_major(shape, input_data); + return convert_layout_tile_to_row_major(shape, tile, input_data); } else { TT_THROW("Unsupported layout conversion"); } @@ -1002,9 +1025,9 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { [&tensor, &target_layout](auto&& storage) -> Tensor { using StorageType = std::decay_t; if constexpr (std::is_same_v) { - return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout); + return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tile()); } else if constexpr (std::is_same_v) { - return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout); + return Tensor(storage, tensor.get_legacy_shape(), tensor.get_dtype(), target_layout, tensor.get_tile()); } else { raise_unsupported_storage(); } @@ -1077,7 +1100,8 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, - tensor.get_layout()) + tensor.get_layout(), + tensor.get_tile()) .to(target_layout); // Convert back to BFLOAT8_B @@ -1091,7 +1115,8 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { std::move(MultiDeviceHostStorage{storage.strategy, output_buffers, storage.shapes}), tensor.get_legacy_shape(), bfloat_enum::value, - target_layout); + target_layout, + tensor.get_tile()); } else { // Convert to FLOAT32 tensor and change layout @@ -1103,7 +1128,8 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { OwnedStorage{input_float_buffer}, tensor.get_legacy_shape(), DataType::FLOAT32, - 
tensor.get_layout()) + tensor.get_layout(), + tensor.get_tile()) .to(target_layout); // Convert back to BFLOAT @@ -1115,7 +1141,8 @@ Tensor to_layout_bfloat(const Tensor& tensor, Layout target_layout) { std::move(OwnedStorage{std::move(output_uint32_buffer)}), tensor.get_legacy_shape(), bfloat_enum::value, - target_layout); + target_layout, + tensor.get_tile()); } }, tensor.get_storage()); @@ -1222,7 +1249,7 @@ Tensor pad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_shape, } }, tensor.get_storage()); - return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout()); + return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); } template Tensor pad( @@ -1310,7 +1337,7 @@ Tensor unpad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tenso } }, tensor.get_storage()); - return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout()); + return Tensor(OwnedStorage{output_buffer}, output_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); } template Tensor unpad(const Tensor& tensor, const tt::tt_metal::LegacyShape& output_tensor_start, const tt::tt_metal::LegacyShape& output_tensor_end); @@ -1345,7 +1372,7 @@ Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id) { auto unpacked_data = tensor_impl::unpack_uint32_vec(device_data); auto output_buffer = owned_buffer::create(std::move(unpacked_data)); - return Tensor(OwnedStorage{output_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()); + return Tensor(OwnedStorage{output_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout(), tensor.get_tile()); } template Tensor extract_shard(const Tensor& tensor, const uint32_t& core_id); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 422d4308377..1a36c7ca8a9 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -24,7 +24,7 @@ namespace tt_metal { namespace tensor_impl { -std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape); +std::array get_sharded_page_shape(Layout layout, DataType dtype, std::array shard_shape, const std::optional& tile); // ----------------------------------------------------------------------------------------------------------------------------------------------- // =============================================================================================================================================== @@ -212,18 +212,23 @@ static std::vector to_vector(const tt::tt_metal::LegacyShape& shape) { } // namespace detail template typename BufferType> -inline std::vector convert_layout_row_major_to_tile(const tt::tt_metal::LegacyShape& shape, const BufferType& data_to_convert) { +inline std::vector convert_layout_row_major_to_tile(const tt::tt_metal::LegacyShape& shape, const Tile& tile, const BufferType& data_to_convert) { TT_FATAL( - (shape[-2] % tt::constants::TILE_HEIGHT == 0 && shape[-1] % tt::constants::TILE_WIDTH == 0), + (shape[-2] % tile.get_tile_shape()[0] == 0 && shape[-1] % tile.get_tile_shape()[1] == 0), "Unsupported shape for tensor conversion"); + + auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; + auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, detail::to_vector(shape), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED32_4FACES); + 
data_to_convert, detail::to_vector(shape), TensorLayout::LIN_ROW_MAJOR, TensorLayout::TILED_NFACES, tile_shape, face_shape); } template typename BufferType> -inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::LegacyShape& shape, const BufferType& data_to_convert) { +inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::LegacyShape& shape, const Tile& tile, const BufferType& data_to_convert) { + auto tile_shape = std::vector{ tile.get_tile_shape()[0], tile.get_tile_shape()[1] }; + auto face_shape = std::vector{ tile.get_face_shape()[0], tile.get_face_shape()[1] }; return convert_layout( - data_to_convert, detail::to_vector(shape), TensorLayout::TILED32_4FACES, TensorLayout::LIN_ROW_MAJOR); + data_to_convert, detail::to_vector(shape), TensorLayout::TILED_NFACES, TensorLayout::LIN_ROW_MAJOR, tile_shape, face_shape); } // ====================================================================================== @@ -235,7 +240,8 @@ void validate_sharded_buffer_allocation( Layout layout, DataType data_type, const ShardSpecBuffer& shard_params, - const MemoryConfig& memory_config); + const MemoryConfig& memory_config, + const std::optional& tile = std::nullopt); // ----------------------------------------------------------------------------------------------------------------------------------------------- // =============================================================================================================================================== // High Level APIs @@ -246,7 +252,7 @@ void validate_sharded_buffer_allocation( // Data reader, writer, and initializers // ====================================================================================== -uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape); +uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape, const std::optional& tile = std::nullopt); DeviceBuffer allocate_buffer_on_device( size_t buffer_size_bytes, @@ -255,7 +261,8 @@ DeviceBuffer allocate_buffer_on_device( DataType data_type, Layout layout, const MemoryConfig& memory_config, - const std::optional& shard_spec = std::nullopt); + const std::optional& shard_spec = std::nullopt, + const std::optional& tile = std::nullopt); template inline void read_data_from_device_buffer( diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 3d2a160c13b..25e2cc60210 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -97,6 +97,7 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers device_tensor.set_shape(input_tensor.get_shape()); device_tensor.set_dtype(input_tensor.get_dtype()); device_tensor.set_layout(input_tensor.get_layout()); + device_tensor.set_tile(input_tensor.get_tile()); device_tensor.tensor_attributes->metadata_populated = true; } }); @@ -138,6 +139,7 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { host_tensor.set_shape(input_tensor.get_shape()); host_tensor.set_dtype(input_tensor.get_dtype()); host_tensor.set_layout(input_tensor.get_layout()); + host_tensor.set_tile(input_tensor.get_tile()); host_tensor.tensor_attributes->metadata_populated = true; } }); @@ -226,6 +228,7 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, MeshDevice* m tensor_modified_layout.set_shape(input_tensor.get_shape()); tensor_modified_layout.set_dtype(input_tensor.get_dtype()); 
tensor_modified_layout.set_layout(target_layout);
+                tensor_modified_layout.set_tile(input_tensor.get_tile());
                 tensor_modified_layout.tensor_attributes->metadata_populated = true;
             };
         });
diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index 1f8750f9ba8..6a2c7574230 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -743,7 +743,7 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
                 using BorrowedStorageType = std::vector>;
                 auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end()));
                 owned_tensor =
-                    Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout());
+                    Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout(), tensor.get_tile());
             },
             borrowed_buffer);
         return owned_tensor;
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 0192b2b1673..4e19691062b 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -128,6 +128,7 @@ def manage_config(name, value):
     CoreRangeSet,
     CoreRange,
     CoreCoord,
+    Tile,
     Layout,
     ROW_MAJOR_LAYOUT,
     TILE_LAYOUT,
diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py
index c3d13f6bfbb..88d9ba523d4 100644
--- a/ttnn/ttnn/operations/core.py
+++ b/ttnn/ttnn/operations/core.py
@@ -172,6 +172,7 @@ def from_torch(
     tensor: "torch.Tensor",
     dtype: Optional[ttnn.DataType] = None,
     *,
+    tile: Optional[ttnn.Tile] = None,
     layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT,
     device: Optional[ttnn.Device] = None,
     memory_config: Optional[ttnn.MemoryConfig] = None,
@@ -217,7 +218,10 @@ def from_torch(
         shards = mesh_mapper.map(tensor)
         tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config())
     else:
-        tensor = ttnn.Tensor(tensor, dtype)
+        if tile is not None:
+            tensor = ttnn.Tensor(tensor, dtype, {}, tile)
+        else:
+            tensor = ttnn.Tensor(tensor, dtype)

     if layout is not None:
         tensor = ttnn.to_layout(tensor, layout, device=device)
diff --git a/ttnn/ttnn/types.py b/ttnn/ttnn/types.py
index 0f8d2c8cbf6..0c95d333381 100644
--- a/ttnn/ttnn/types.py
+++ b/ttnn/ttnn/types.py
@@ -38,6 +38,8 @@
 TILE_SIZE = 32

+Tile = ttnn._ttnn.tensor.Tile
+
 Shape = ttnn._ttnn.types.Shape
 Tensor = ttnn._ttnn.tensor.Tensor
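
The Python-facing pieces of this change are the Tile export in ttnn/__init__.py and ttnn/types.py, the tile keyword on ttnn.from_torch, and the output_tile keyword registered on ttnn.matmul and ttnn.linear in matmul_pybind.cpp. A minimal usage sketch follows; it is not part of the patch, and the Tile([height, width]) constructor form and the surrounding device setup are assumptions, not something this diff shows.

# Usage sketch (not part of the patch). Only the keyword arguments tile= and output_tile=
# come from the bindings above; the Tile constructor form and device setup are assumed.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Default 32x32 tile; per the warning added in tensor.cpp, only matmul currently
# honors a customized tile shape, so anything else stays 32x32 here.
tile = ttnn.Tile([32, 32])

a = ttnn.from_torch(torch.randn(1, 1, 64, 64), dtype=ttnn.bfloat16, tile=tile,
                    layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.randn(1, 1, 64, 64), dtype=ttnn.bfloat16, tile=tile,
                    layout=ttnn.TILE_LAYOUT, device=device)

# output_tile is the new optional argument; it is forwarded into the matmul
# program config and ends up as the Tile stored on the output tensor.
out = ttnn.matmul(a, b, output_tile=ttnn.Tile([32, 32]))

ttnn.close_device(device)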