From 300fe91838ed2621c42dbfb042b335d942e0d047 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Wed, 27 Nov 2024 21:30:09 +0000 Subject: [PATCH 01/20] #14790: add sub-height sharding support for transpose wh --- .../unit_testing/misc/test_transpose.py | 75 +++++++++++++++++++ .../transpose/device/transpose_op.cpp | 33 +++++--- .../device/transpose_program_factory.cpp | 75 ++++++++++++++----- 3 files changed, 152 insertions(+), 31 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 489b25ba5e9..ef89d551683 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -1011,3 +1011,78 @@ def test_transpose_forge_hc(device, b, h, w, dim0, dim1): output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(torch_output_tensor, output_tensor) + + +@pytest.mark.parametrize("n", [1]) +@pytest.mark.parametrize("c", [1]) +@pytest.mark.parametrize("h", [256]) +@pytest.mark.parametrize("w", [32]) +def test_tranpose_hw_sharded_tiled_8_cores(device, n, c, h, w): + torch.manual_seed(2005) + torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch_input_tensor.transpose(2, 3) + tt_input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=ttnn.DataType.BFLOAT16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + sharded_mem_config = ttnn.create_sharded_memory_config( + (32, 32), + core_grid=ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 6)), + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(1, 0)), + } + ), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.COL_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config) + + tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config) + tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) + + assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) + + +@pytest.mark.parametrize("n", [1]) +@pytest.mark.parametrize("c", [1]) +@pytest.mark.parametrize("h", [224]) +@pytest.mark.parametrize("w", [32]) +def test_tranpose_hw_sharded_tiled_n_cores(device, n, c, h, w): + torch.manual_seed(2005) + torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) + torch_output_tensor = torch_input_tensor.transpose(2, 3) + tt_input_tensor = ttnn.from_torch( + torch_input_tensor, + dtype=ttnn.DataType.BFLOAT16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + + sharded_mem_config = ttnn.create_sharded_memory_config( + (32, 32), + core_grid=ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, h // 32 - 1)), + } + ), + strategy=ttnn.ShardStrategy.HEIGHT, + orientation=ttnn.ShardOrientation.COL_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config) + + tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config) + tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG) + tt_output_tensor = ttnn.from_device(tt_output_tensor) + tt_output_tensor = ttnn.to_torch(tt_output_tensor) + + assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 1776e28adc8..0365d92c21c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -43,22 +43,28 @@ void Transpose::validate(const std::vector& input_tensors) const { if (input_tensor.is_sharded()) { TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); const auto shard_spec = input_tensor.shard_spec().value(); - TT_FATAL(shard_spec.shape[1] == W, "Error"); - TT_FATAL(shard_spec.shape[0] % H == 0, "Error"); + TT_FATAL( + (shard_spec.shape[0] % H == 0) || (H % shard_spec.shape[0] == 0), + "Only a multiple of H or a factor of H is allows for the shard height"); + TT_FATAL(shard_spec.shape[1] == W, "Only height sharding is supported"); TT_FATAL(this->output_mem_config.is_sharded(), "Error"); TT_FATAL(this->output_mem_config.memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); } else { - TT_FATAL(!this->output_mem_config.is_sharded(), "Error"); + TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved inputs cannot output sharded outputs"); } } else { if (input_tensor.is_sharded()) { - TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL( + input_tensor.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Only height sharding is supported for transpose hc"); const auto shard_spec = input_tensor.shard_spec().value(); - TT_FATAL(shard_spec.shape[1] == W, "Error"); - TT_FATAL(this->output_mem_config.is_sharded(), "Error"); - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL(shard_spec.shape[1] == W, "Block/Width sharding is not supported"); + TT_FATAL(this->output_mem_config.is_sharded(), "Sharded input can only output sharded tensors"); + TT_FATAL( + this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Only height sharding is supported"); } else { - TT_FATAL(!this->output_mem_config.is_sharded(), "Error"); + TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved inputs cannot output sharded outputs"); } } if (this->dim == TransposeOpDim::HC) { @@ -147,9 +153,14 @@ std::vector Transpose::compute_output_specs(const std::vector< if (this->dim == TransposeOpDim::WH) { const auto& input_padded_shape = input_tensor.get_padded_shape(); ShardSpec shard_spec = input_tensor.shard_spec().value(); - shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1]; - shard_spec.shape[1] = input_padded_shape[-2]; - output_mem_config.shard_spec = shard_spec; + if (shard_spec.shape[0] >= input_padded_shape[-2]) { + shard_spec.shape[0] = shard_spec.shape[0] / input_padded_shape[-2] * input_padded_shape[-1]; + shard_spec.shape[1] = input_padded_shape[-2]; + output_mem_config.shard_spec = shard_spec; + } else { + std::swap(shard_spec.shape[0], shard_spec.shape[1]); + output_mem_config.shard_spec = shard_spec; + } } else if (this->dim == TransposeOpDim::HC) { output_mem_config.shard_spec = input_tensor.shard_spec().value(); } else { diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp index 2dbdcf7dc8a..02bc49406d8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp @@ -1776,8 +1776,9 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, uint32_t dst_single_tile_size = tt::tt_metal::detail::TileSize(dst_cb_data_format); tt::tt_metal::Buffer* src0_buffer = a.buffer(); - - int32_t num_tiles = a.volume() / TILE_HW; + const auto tile = a.get_tensor_spec().tile(); + const uint32_t tile_hw = tile.get_tile_hw(); + int32_t num_tiles = a.volume() / tile_hw; tt::tt_metal::Device* device = a.device(); @@ -1793,7 +1794,7 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, auto& all_cores = shard_spec.grid; uint32_t num_cores = all_cores.num_cores(); - uint32_t num_tiles_per_shard = shard_spec.numel() / TILE_HW; + uint32_t num_tiles_per_shard = shard_spec.numel() / tile_hw; tt::tt_metal::LegacyShape output_shape = output.get_legacy_shape(); @@ -1848,11 +1849,22 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, total_cores, tt::tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_compile_time_args}); - uint32_t Wt = shard_spec.shape[1] / TILE_WIDTH; - uint32_t Ht = a.get_legacy_shape()[-2] / TILE_HEIGHT; - uint32_t HtWt = Ht * Wt; - uint32_t N = shard_spec.shape[0] / a.get_legacy_shape()[-2]; - uint32_t NHtWt = N * HtWt; + auto padded_shape = a.get_padded_shape(); + auto shard_shape = shard_spec.shape; + + uint32_t H = padded_shape[2], W = padded_shape[3]; + uint32_t Hs = shard_shape[0], Ws = shard_shape[1]; + + uint32_t Hts = Hs / tile.tile_shape[0]; + uint32_t Wts = Ws / tile.tile_shape[1]; + + uint32_t Ht = H / tile.tile_shape[0]; + uint32_t Ht_per_shard = std::min(Ht, Hts); + + uint32_t num_hw_blocks_per_shard = Hts > Ht ? Hts / Ht : 1; + + uint32_t HtWt_tile_size = Ht_per_shard * Wts; + uint32_t num_blocks = num_hw_blocks_per_shard * HtWt_tile_size; auto bbox = all_cores.bounding_box(); std::vector cores = @@ -1862,13 +1874,17 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, std::vector> unary_compute_args = {cores.size(), std::vector(5)}; std::vector> unary_writer_args = {cores.size(), std::vector(1)}; std::fill( - unary_reader_args.begin(), unary_reader_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_reader_args.begin(), + unary_reader_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); std::fill( unary_compute_args.begin(), unary_compute_args.begin() + all_cores.num_cores(), - std::vector{NHtWt, HtWt, N, Ht, Wt}); + std::vector{num_blocks, HtWt_tile_size, num_hw_blocks_per_shard, Ht_per_shard, Wts}); std::fill( - unary_writer_args.begin(), unary_writer_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_writer_args.begin(), + unary_writer_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, cores, unary_reader_args); tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, cores, unary_compute_args); @@ -1899,7 +1915,11 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, auto shard_spec = src_tensor.shard_spec().value(); - uint32_t num_tiles_per_shard = shard_spec.numel() / TILE_HW; + const auto tile = src_tensor.get_tensor_spec().tile(); + const uint32_t tile_hw = tile.get_tile_hw(); + int32_t num_tiles = src_tensor.volume() / tile_hw; + + uint32_t num_tiles_per_shard = shard_spec.numel() / tile_hw; if (src0_sharded) { UpdateDynamicCircularBufferAddressAndTotalSize( @@ -1911,11 +1931,22 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, program, cb_output, *dst_buffer, num_tiles_per_shard * dst_single_tile_size); } - uint32_t Wt = shard_spec.shape[1] / TILE_WIDTH; - uint32_t Ht = src_tensor.get_legacy_shape()[-2] / TILE_HEIGHT; - uint32_t HtWt = Ht * Wt; - uint32_t N = shard_spec.shape[0] / src_tensor.get_legacy_shape()[-2]; - uint32_t NHtWt = N * HtWt; + auto padded_shape = src_tensor.get_padded_shape(); + auto shard_shape = shard_spec.shape; + + uint32_t H = padded_shape[2], W = padded_shape[3]; + uint32_t Hs = shard_shape[0], Ws = shard_shape[1]; + + uint32_t Hts = Hs / tile.tile_shape[0]; + uint32_t Wts = Ws / tile.tile_shape[1]; + + uint32_t Ht = H / tile.tile_shape[0]; + uint32_t Ht_per_shard = std::min(Ht, Hts); + + uint32_t num_hw_blocks_per_shard = Hts > Ht ? Hts / Ht : 1; + + uint32_t HtWt_tile_size = Ht_per_shard * Wts; + uint32_t num_blocks = num_hw_blocks_per_shard * HtWt_tile_size; const auto& all_cores = shard_spec.grid; bool row_major = shard_spec.orientation == ShardOrientation::ROW_MAJOR; @@ -1927,13 +1958,17 @@ operation::ProgramWithCallbacks transpose_wh_multi_core_sharded(const Tensor& a, std::vector> unary_compute_args = {cores.size(), std::vector(5)}; std::vector> unary_writer_args = {cores.size(), std::vector(1)}; std::fill( - unary_reader_args.begin(), unary_reader_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_reader_args.begin(), + unary_reader_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); std::fill( unary_compute_args.begin(), unary_compute_args.begin() + all_cores.num_cores(), - std::vector{NHtWt, HtWt, N, Ht, Wt}); + std::vector{num_blocks, HtWt_tile_size, num_hw_blocks_per_shard, Ht_per_shard, Wts}); std::fill( - unary_writer_args.begin(), unary_writer_args.begin() + all_cores.num_cores(), std::vector{NHtWt}); + unary_writer_args.begin(), + unary_writer_args.begin() + all_cores.num_cores(), + std::vector{num_blocks}); tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, cores, unary_reader_args); tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, cores, unary_compute_args); From b21dc6f964ea6cb562cc4bce53dd562b8ce846cc Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Fri, 29 Nov 2024 07:18:22 +0000 Subject: [PATCH 02/20] #0: correct memory layout --- .../operations/data_movement/transpose/device/transpose_op.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 0365d92c21c..05bc7b3f6cf 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -160,6 +160,7 @@ std::vector Transpose::compute_output_specs(const std::vector< } else { std::swap(shard_spec.shape[0], shard_spec.shape[1]); output_mem_config.shard_spec = shard_spec; + output_mem_config.memory_layout = TensorMemoryLayout::BLOCK_SHARDED; } } else if (this->dim == TransposeOpDim::HC) { output_mem_config.shard_spec = input_tensor.shard_spec().value(); From a833cb976437196e1daacad8c99631898c53e02d Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Fri, 29 Nov 2024 07:32:34 +0000 Subject: [PATCH 03/20] #0: double buffer permute kernel --- .../device/kernels/dataflow/reader_permute_interleaved_rm.cpp | 1 - .../data_movement/permute/device/permute_program_factory.cpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp index 73241cb9703..b5ffc12cf7d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp @@ -21,7 +21,6 @@ void kernel_main() { uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); noc_async_read_page(i, s0, src_buffer_l1_addr); noc_async_read_barrier(); - volatile tt_l1_ptr uint16_t* out_stick = reinterpret_cast(src_buffer_l1_addr); cb_push_back(tt::CBIndex::c_0, 1); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 29f6065cb5b..6ae2596bcf5 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -55,7 +55,7 @@ PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::Sin tt::tt_metal::Device* device = input_tensor.device(); uint32_t src0_cb_index = tt::CBIndex::c_0; - uint32_t num_input_pages_to_read = 1; + uint32_t num_input_pages_to_read = 2; CoreRange core({0, 0}, {0, 0}); tt::tt_metal::CircularBufferConfig cb_src0_config = From cabfcc9c3e418cad5363a89046573e492e14c65f Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 3 Dec 2024 17:47:27 +0000 Subject: [PATCH 04/20] #15165: add N-d RM permute support that moves width --- .../unit_tests/operations/test_permute.py | 30 +++ ...r_permute_interleaved_rm_width_permute.cpp | 87 +++++++++ ...r_permute_interleaved_rm_width_permute.cpp | 174 ++++++++++++++++++ .../device/permute_device_operation.cpp | 9 +- .../device/permute_device_operation.hpp | 23 ++- .../device/permute_program_factory.cpp | 131 ++++++++++++- 6 files changed, 447 insertions(+), 7 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 40a57515f56..186f1e2f4c5 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -7,6 +7,7 @@ import torch import ttnn +import itertools from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import is_blackhole @@ -171,3 +172,32 @@ def test_permute_pad_value(device, pad_value): assert ttnn.to_torch(a) == float("-inf") tt_output = ttnn.to_torch(tt_output) assert_with_pcc(torch_output, tt_output, 0.9999) + + +def generate_permutations(N): + """ + Generator function that yields all permutations of tuples with values 0 to N-1. + + :param N: The number defining the range of values (0 to N-1). + :yield: Tuples representing each permutation. + """ + for perm in itertools.permutations(range(N)): + yield perm + + +@pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)]) +@pytest.mark.parametrize("perm", generate_permutations(5)) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +def test_permute_5d_width(shape, perm, memory_config, device): + torch.manual_seed(2005) + input_a = torch.randn(shape) + # print(input_a) + torch_output = torch.permute(input_a, perm) + + tt_input = ttnn.from_torch( + input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, memory_config=memory_config + ) + + tt_output = ttnn.permute(tt_input, perm) + tt_output = ttnn.to_torch(tt_output) + assert_with_pcc(torch_output, tt_output, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp new file mode 100644 index 00000000000..0c060289dab --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp @@ -0,0 +1,87 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "debug/dprint.h" + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } + DPRINT << ENDL(); +} + +void kernel_main() { + constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + constexpr uint32_t x_dim = get_compile_time_arg_val(4); + + const uint32_t src_addr = get_arg_val(0); + const DataFormat data_format = get_dataformat(tt::CBIndex::c_0); + + uint32_t input_shape[N], src_strides[N]; + for (uint32_t i = 1; i <= N; i++) { + input_shape[i - 1] = get_arg_val(i); + src_strides[i - 1] = get_arg_val(i + N); + } + + uint32_t X = input_shape[x_dim]; + uint32_t X_stride = src_strides[x_dim]; + + // for (uint32_t i = 0; i < N; i++) { + // DPRINT << "input_shape[" << i << "] = " << input_shape[i] << " "; + // } + // DPRINT << ENDL(); + // for (uint32_t i = 0; i < N; i++) { + // DPRINT << "src_strides[" << i << "] = " << src_strides[i] << " "; + // } + // DPRINT << ENDL(); + + const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; + + uint32_t curr_addr = src_addr; + // DPRINT << "Reading " << num_rows << " rows of " << X << " elements each" << ENDL(); + // DPRINT << "X dimension: " << x_dim << ENDL(); + for (uint32_t i = 0; i < num_rows/X; ++i) { + uint32_t idxs[N]; + idxs[N - 1] = 0; + uint32_t remainder = i; + for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + if (d == (int32_t)x_dim) { + continue; // Skip X dimension + } + idxs[d] = remainder % input_shape[d]; + remainder /= input_shape[d]; + } + cb_reserve_back(tt::CBIndex::c_0, X); + uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + for (uint32_t j = 0; j < X; ++j) { + idxs[x_dim] = j; + // for (uint32_t k = 0; k < N; ++k) { + // DPRINT << "idxs[" << k << "] = " << idxs[k] << " "; + // } + // Compute the address using indices and strides + uint64_t addr_offset = 0; + for (uint32_t d = 0; d < N; ++d) { + addr_offset += idxs[d] * src_strides[d]; + } + // DPRINT << "Reading page " << addr_offset << " into buffer " << ENDL(); + uint64_t src_noc_addr = get_noc_addr(addr_offset, s0); + noc_async_read(src_noc_addr, src_buffer_l1_addr, page_size); + src_buffer_l1_addr += page_size; + } + noc_async_read_barrier(); + // src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + // print_pages(src_buffer_l1_addr, 8, X, 0); + cb_push_back(tt::CBIndex::c_0, X); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp new file mode 100644 index 00000000000..03fd8a78e9f --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp @@ -0,0 +1,174 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +// Function template to swap two elements in a uint32_t array +template +FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { + // Perform the swap + uint32_t temp = array[i]; + array[i] = array[j]; + array[j] = temp; +} + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } + DPRINT << ENDL(); +} + +FORCE_INLINE void transpose_XW_to_WX(uint32_t input_l1_addr, uint32_t output_l1_addr, uint32_t X, uint32_t W, uint32_t element_size, uint32_t input_page_size, uint32_t output_page_size) { + volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); + volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); + // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner + // each element is element_size bytes + // each row is W elements, and each row is separated by input_page_size bytes + // each output row is X elements, and each row is separated by output_page_size bytes + + for (uint32_t x = 0; x < X; ++x) { + for (uint32_t w = 0; w < W; ++w) { + // Compute the input and output addresses + uint32_t input_addr = x * input_page_size + w * element_size; + uint32_t output_addr = w * output_page_size + x * element_size; + // Copy the element - do we have memcpy? use this for now + for (uint32_t i = 0; i < element_size; ++i) { + output_ptr[output_addr + i] = input_ptr[input_addr + i]; + } + } + } +} + +void kernel_main() { + constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t output_page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + + constexpr uint32_t X = get_compile_time_arg_val(4); + // X_stride and x_dim along the input tensor + constexpr uint32_t X_stride = get_compile_time_arg_val(5); + constexpr uint32_t x_dim = get_compile_time_arg_val(6); + + constexpr uint32_t W = get_compile_time_arg_val(7); + // W_stride and w_dim along the output tensor + constexpr uint32_t W_stride = get_compile_time_arg_val(8); + constexpr uint32_t input_page_size = get_compile_time_arg_val(9); + constexpr uint32_t element_size_bytes = get_compile_time_arg_val(10); + + constexpr uint32_t w_dim = N - 1; + + // // DPRINT << "N = " << N << ENDL(); + // DPRINT << "output_page_size = " << output_page_size << ENDL(); + // // DPRINT << "num_rows = " << num_rows << ENDL(); + // // DPRINT << "X = " << X << ENDL(); + // // DPRINT << "X_stride = " << X_stride << ENDL(); + // // DPRINT << "x_dim = " << x_dim << ENDL(); + // // DPRINT << "W = " << W << ENDL(); + // DPRINT << "W_stride = " << W_stride << ENDL(); + // // DPRINT << "w_dim = " << w_dim << ENDL() << ENDL(); + // DPRINT << "input_page_size = " << input_page_size << ENDL(); + // DPRINT << "element_size_bytes = " << element_size_bytes << ENDL(); + + + const uint32_t dst_addr = get_arg_val(0); + + const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_page_size}; + + uint32_t input_shape[N], perm[N], dest_strides[N]; + for (uint32_t i = 1; i <= N; i++) { + input_shape[i - 1] = get_arg_val(i); + perm[i - 1] = get_arg_val(i + N); + dest_strides[i - 1] = get_arg_val(i + 2 * N); + } + + // after we transpose w and x, the perm is different + // perm[i] = x_dim becomes perm[i] == N - 1 as it's been in the correct position (i == N - 1) + // perm[j] = N - 1 becomes perm[j] = x_dim as x_dim has been moved to the end + // input shape[i] = X becomes input_shape[i] = W as we're transposing the X dimension + // input shape[j] = W becomes input_shape[j] = X as we're transposing the W dimension + // dest strides is slightly incorrect as it assumes W is in its final position + swap_elements(input_shape, x_dim, w_dim); + for (uint32_t i = 0; i < N; i++) { + if (perm[i] == x_dim) { + perm[i] = N - 1; + } else if (perm[i] == N - 1) { + perm[i] = x_dim; + } + } + + // for (uint32_t i = 0; i < N; i++) { + // DPRINT << "input_shape[" << i << "] = " << input_shape[i] << " "; + // } + // DPRINT << ENDL(); + // for (uint32_t i = 0; i < N; i++) { + // DPRINT << "perm[" << i << "] = " << perm[i] << " "; + // } + // DPRINT << ENDL(); + // for (uint32_t i = 0; i < N; i++) { + // DPRINT << "dest_strides[" << i << "] = " << dest_strides[i] << " "; + // } + // DPRINT << ENDL() << ENDL(); + + + uint32_t curr_addr = dst_addr; + for (uint32_t block = 0; block < num_rows/X; ++block) { + // Compute multi-dimensional index for the source block + uint32_t src_multi_idx[N]; + size_t remaining = block; + for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + if (d == (int32_t)x_dim) { + continue; // Skip post-XW transpose w dimension + } + src_multi_idx[d] = remaining % input_shape[d]; + remaining /= input_shape[d]; + } + // Apply permutation to get destination multi-dimensional index + + src_multi_idx[N - 1] = 0; // Row dimension index + src_multi_idx[N - 1] = 0; // Row dimension index + cb_wait_front(tt::CBIndex::c_0, X); + uint32_t src_buffer_l1_addr = get_read_ptr(tt::CBIndex::c_0); + print_pages(src_buffer_l1_addr, tt::data_movement::common::round_up(), X, 0); + // Transpose the X*W*element_size block + uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); + transpose_XW_to_WX(src_buffer_l1_addr, transposed_buffer_read_addr, X, W, element_size_bytes, input_page_size, output_page_size); + print_pages(transposed_buffer_read_addr, tt::data_movement::common::round_up(), W, 0); + for (uint32_t w = 0; w < W; ++w) { + src_multi_idx[x_dim] = w; + uint32_t dest_multi_idx[N]; + for (uint32_t i = 0; i < N; ++i) { + dest_multi_idx[i] = src_multi_idx[perm[i]]; + } + for (uint32_t i = 0; i < N; i++) { + DPRINT << "src_multi_idx[" << i << "] = " << src_multi_idx[i] << " "; + } + DPRINT << ENDL(); + for (uint32_t i = 0; i < N; i++) { + DPRINT << "dest_multi_idx[" << i << "] = " << dest_multi_idx[i] << " "; + } + DPRINT << ENDL(); + + // Convert destination multi-dimensional index to linear index + uint32_t dest_linear_idx = 0; + for (uint32_t i = 0; i < N - 1; ++i) { + dest_linear_idx += dest_multi_idx[i] * dest_strides[i]; + } + + uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0); + noc_async_write(transposed_buffer_read_addr + w * output_page_size, dst_noc_addr, output_page_size); + noc_async_write_barrier(); + } + cb_pop_front(tt::CBIndex::c_0, X); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index bbe319681bb..a3345096f54 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -12,7 +12,10 @@ namespace ttnn::operations::data_movement { PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return SingleCore{}; + if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) { + return SingleCore{}; + } + return SingleCoreWidthPermute{}; } void PermuteDeviceOperation::validate_on_program_cache_miss( @@ -20,10 +23,6 @@ void PermuteDeviceOperation::validate_on_program_cache_miss( TT_FATAL( attributes.dims.size() == tensor_args.input_tensor.get_logical_shape().rank(), "Permute dimensions must match input tensor rank"); - TT_FATAL( - attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1, - "Last dimension of permute must be the last dimension of the input tensor as page-breaking is not supported at " - "the moment"); TT_FATAL(tensor_args.input_tensor.is_sharded() == false, "Permute operation does not support sharded input tensor"); TT_FATAL( tensor_args.input_tensor.get_layout() == Layout::ROW_MAJOR, "Permute operation only supports row-major layout"); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 2f9481feb8c..91756172a4f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -49,7 +49,28 @@ struct PermuteDeviceOperation { const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + + struct SingleCoreWidthPermute { + // Shared variables are the variables that are shared between the create and override_runtime_arguments methods + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 6ae2596bcf5..3f59a304582 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -14,8 +14,9 @@ uint32_t num_pages(const ttnn::Tensor& input_tensor) { } uint32_t page_size(const ttnn::Tensor& input_tensor) { + auto BUFFER_ALIGNMENT = input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT; const auto& padded_shape = input_tensor.get_logical_shape(); // in anticipation of RM padding - return padded_shape[-1] * input_tensor.element_size(); + return tt::round_up(padded_shape[-1] * input_tensor.element_size(), BUFFER_ALIGNMENT); } std::vector get_row_strides(const ttnn::SimpleShape& shape) { @@ -27,6 +28,7 @@ std::vector get_row_strides(const ttnn::SimpleShape& shape) { } return strides; } + } // namespace detail PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::SingleCore::create( @@ -130,4 +132,131 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments( } } +PermuteDeviceOperation::SingleCoreWidthPermute::cached_program_t PermuteDeviceOperation::SingleCoreWidthPermute::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + + tt::tt_metal::Program program{}; + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t input_rm_page_size = detail::page_size(input_tensor); + + tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + uint32_t output_rm_page_size = detail::page_size(tensor_return_value); + + uint32_t num_input_pages = detail::num_pages(input_tensor); + + tt::tt_metal::Device* device = input_tensor.device(); + + uint32_t src0_cb_index = tt::CBIndex::c_0; + uint32_t src1_cb_index = tt::CBIndex::c_1; + uint32_t num_input_pages_to_read = 2; + + // we are focused on reading one row at a time, in a pattern that allows us to write an entire output row at a time + // if W is being swapped with another dim X (e.g. H), then we need to read X rows at a time (X is the new row dimension) + // CB is thus X pages in size (X*W*element_size) + // we read in X input rows of size W, and write out W output rows of size X + // find the new row dimension (X) + + uint32_t x_dim = operation_attributes.dims.back(); + uint32_t X = input_tensor.get_logical_shape()[x_dim]; + // stride from one row to the next for each dim in the input tensor + auto input_strides = detail::get_row_strides(input_tensor.get_logical_shape()); + uint32_t X_stride = input_strides[x_dim]; + + auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); + // after we transpose X and W, we need to stride from one row to the next for each dim in the output tensor + uint32_t W = input_tensor.get_logical_shape()[-1]; + uint32_t W_stride = output_strides[x_dim]; + + CoreRange core({0, 0}, {0, 0}); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * input_rm_page_size * X, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, input_rm_page_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * output_rm_page_size * W, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, output_rm_page_size); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); + + uint32_t N = operation_attributes.dims.size(); + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows, x_dim}; + + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp", + core, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows, X, X_stride, x_dim, W, W_stride, input_rm_page_size, input_tensor.element_size()}; + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp", + core, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + auto input_shape_view = input_tensor.get_logical_shape().view(); + + std::vector reader_runtime_args = {src_buffer->address()}; + reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); + + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + + + std::vector writer_runtime_args = {dst_buffer->address()}; + writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + writer_runtime_args.insert( + writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); + writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); + + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + + return { + std::move(program), + {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}}; +} + +void PermuteDeviceOperation::SingleCoreWidthPermute::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + auto& program = cached_program.program; + auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + + { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); + runtime_args[0] = src_buffer->address(); + } + + { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); + runtime_args[0] = dst_buffer->address(); + } +} + } // namespace ttnn::operations::data_movement From 7b7b12f927c5e163905d5ae8360768a11be2beed Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 3 Dec 2024 20:10:56 +0000 Subject: [PATCH 05/20] #15165: add some N-d permute optimizations --- .../unit_tests/operations/test_permute.py | 5 +- ...r_permute_interleaved_rm_width_permute.cpp | 47 ++++---- ...r_permute_interleaved_rm_width_permute.cpp | 111 +++++++----------- .../device/permute_program_factory.cpp | 1 + 4 files changed, 69 insertions(+), 95 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 186f1e2f4c5..aa394c98af5 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -188,14 +188,15 @@ def generate_permutations(N): @pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)]) @pytest.mark.parametrize("perm", generate_permutations(5)) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) -def test_permute_5d_width(shape, perm, memory_config, device): +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) +def test_permute_5d_width(shape, perm, memory_config, dtype, device): torch.manual_seed(2005) input_a = torch.randn(shape) # print(input_a) torch_output = torch.permute(input_a, perm) tt_input = ttnn.from_torch( - input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.bfloat16, memory_config=memory_config + input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config ) tt_output = ttnn.permute(tt_input, perm) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp index 0c060289dab..be32533c2d3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp @@ -37,51 +37,46 @@ void kernel_main() { uint32_t X = input_shape[x_dim]; uint32_t X_stride = src_strides[x_dim]; - // for (uint32_t i = 0; i < N; i++) { - // DPRINT << "input_shape[" << i << "] = " << input_shape[i] << " "; - // } - // DPRINT << ENDL(); - // for (uint32_t i = 0; i < N; i++) { - // DPRINT << "src_strides[" << i << "] = " << src_strides[i] << " "; - // } - // DPRINT << ENDL(); - const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; uint32_t curr_addr = src_addr; - // DPRINT << "Reading " << num_rows << " rows of " << X << " elements each" << ENDL(); - // DPRINT << "X dimension: " << x_dim << ENDL(); + uint32_t idxs[N]; + idxs[N - 1] = 0; for (uint32_t i = 0; i < num_rows/X; ++i) { - uint32_t idxs[N]; - idxs[N - 1] = 0; + // Map linear index i to multidimensional indices idxs[] uint32_t remainder = i; for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension if (d == (int32_t)x_dim) { - continue; // Skip X dimension + idxs[d] = 0; // Initialize x_dim to zero (will be set in inner loop) + continue; // Skip x_dim during mapping } idxs[d] = remainder % input_shape[d]; remainder /= input_shape[d]; } + idxs[N - 1] = 0; // Initialize W dimension index to zero if not already set + + // Precompute the base address offset (excluding x_dim) + uint64_t base_addr_offset = 0; + for (uint32_t d = 0; d < N; ++d) { + if (d != x_dim) { + base_addr_offset += idxs[d] * src_strides[d]; + } + } + cb_reserve_back(tt::CBIndex::c_0, X); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + + // Read along the X dimension for (uint32_t j = 0; j < X; ++j) { - idxs[x_dim] = j; - // for (uint32_t k = 0; k < N; ++k) { - // DPRINT << "idxs[" << k << "] = " << idxs[k] << " "; - // } - // Compute the address using indices and strides - uint64_t addr_offset = 0; - for (uint32_t d = 0; d < N; ++d) { - addr_offset += idxs[d] * src_strides[d]; - } - // DPRINT << "Reading page " << addr_offset << " into buffer " << ENDL(); + // Set the index for the X dimension + uint32_t idx_x = j; + // Compute the address offset for this index + uint64_t addr_offset = base_addr_offset + idx_x * X_stride; uint64_t src_noc_addr = get_noc_addr(addr_offset, s0); noc_async_read(src_noc_addr, src_buffer_l1_addr, page_size); src_buffer_l1_addr += page_size; } noc_async_read_barrier(); - // src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); - // print_pages(src_buffer_l1_addr, 8, X, 0); cb_push_back(tt::CBIndex::c_0, X); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp index 03fd8a78e9f..fdc61bc04c4 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp @@ -56,34 +56,22 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); constexpr uint32_t X = get_compile_time_arg_val(4); - // X_stride and x_dim along the input tensor constexpr uint32_t X_stride = get_compile_time_arg_val(5); constexpr uint32_t x_dim = get_compile_time_arg_val(6); constexpr uint32_t W = get_compile_time_arg_val(7); - // W_stride and w_dim along the output tensor constexpr uint32_t W_stride = get_compile_time_arg_val(8); constexpr uint32_t input_page_size = get_compile_time_arg_val(9); constexpr uint32_t element_size_bytes = get_compile_time_arg_val(10); constexpr uint32_t w_dim = N - 1; - // // DPRINT << "N = " << N << ENDL(); - // DPRINT << "output_page_size = " << output_page_size << ENDL(); - // // DPRINT << "num_rows = " << num_rows << ENDL(); - // // DPRINT << "X = " << X << ENDL(); - // // DPRINT << "X_stride = " << X_stride << ENDL(); - // // DPRINT << "x_dim = " << x_dim << ENDL(); - // // DPRINT << "W = " << W << ENDL(); - // DPRINT << "W_stride = " << W_stride << ENDL(); - // // DPRINT << "w_dim = " << w_dim << ENDL() << ENDL(); - // DPRINT << "input_page_size = " << input_page_size << ENDL(); - // DPRINT << "element_size_bytes = " << element_size_bytes << ENDL(); - - const uint32_t dst_addr = get_arg_val(0); - const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_page_size}; + const InterleavedAddrGen s0 = { + .bank_base_address = dst_addr, + .page_size = output_page_size + }; uint32_t input_shape[N], perm[N], dest_strides[N]; for (uint32_t i = 1; i <= N; i++) { @@ -92,83 +80,72 @@ void kernel_main() { dest_strides[i - 1] = get_arg_val(i + 2 * N); } - // after we transpose w and x, the perm is different - // perm[i] = x_dim becomes perm[i] == N - 1 as it's been in the correct position (i == N - 1) - // perm[j] = N - 1 becomes perm[j] = x_dim as x_dim has been moved to the end - // input shape[i] = X becomes input_shape[i] = W as we're transposing the X dimension - // input shape[j] = W becomes input_shape[j] = X as we're transposing the W dimension - // dest strides is slightly incorrect as it assumes W is in its final position + // Adjust for the transpose between X and W dimensions swap_elements(input_shape, x_dim, w_dim); for (uint32_t i = 0; i < N; i++) { if (perm[i] == x_dim) { - perm[i] = N - 1; - } else if (perm[i] == N - 1) { + perm[i] = w_dim; + } else if (perm[i] == w_dim) { perm[i] = x_dim; } } - // for (uint32_t i = 0; i < N; i++) { - // DPRINT << "input_shape[" << i << "] = " << input_shape[i] << " "; - // } - // DPRINT << ENDL(); - // for (uint32_t i = 0; i < N; i++) { - // DPRINT << "perm[" << i << "] = " << perm[i] << " "; - // } - // DPRINT << ENDL(); - // for (uint32_t i = 0; i < N; i++) { - // DPRINT << "dest_strides[" << i << "] = " << dest_strides[i] << " "; - // } - // DPRINT << ENDL() << ENDL(); - - - uint32_t curr_addr = dst_addr; - for (uint32_t block = 0; block < num_rows/X; ++block) { - // Compute multi-dimensional index for the source block - uint32_t src_multi_idx[N]; + uint32_t x_dim_in_dest = N; // Invalid index + for (uint32_t i = 0; i < N; ++i) { + if (perm[i] == x_dim) { + x_dim_in_dest = i; + break; + } + } + uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); + uint32_t src_multi_idx[N] = {0}; + uint32_t dest_multi_idx[N] = {0}; + for (uint32_t block = 0; block < num_rows / X; ++block) { + // Compute source indices size_t remaining = block; for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension if (d == (int32_t)x_dim) { - continue; // Skip post-XW transpose w dimension + continue; // Skip x_dim } src_multi_idx[d] = remaining % input_shape[d]; remaining /= input_shape[d]; } - // Apply permutation to get destination multi-dimensional index - src_multi_idx[N - 1] = 0; // Row dimension index - src_multi_idx[N - 1] = 0; // Row dimension index + // Precompute dest_multi_idx and dest_linear_idx_base + uint32_t dest_linear_idx_base = 0; + for (uint32_t i = 0; i < N; ++i) { + uint32_t src_idx = perm[i]; + if (src_idx != x_dim) { + dest_multi_idx[i] = src_multi_idx[src_idx]; + if (i < N - 1) { // Exclude W dimension + dest_linear_idx_base += dest_multi_idx[i] * dest_strides[i]; + } + } + } + cb_wait_front(tt::CBIndex::c_0, X); uint32_t src_buffer_l1_addr = get_read_ptr(tt::CBIndex::c_0); - print_pages(src_buffer_l1_addr, tt::data_movement::common::round_up(), X, 0); - // Transpose the X*W*element_size block - uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); + + // Transpose the block transpose_XW_to_WX(src_buffer_l1_addr, transposed_buffer_read_addr, X, W, element_size_bytes, input_page_size, output_page_size); - print_pages(transposed_buffer_read_addr, tt::data_movement::common::round_up(), W, 0); + + + + // Update only the changing components inside the loop for (uint32_t w = 0; w < W; ++w) { src_multi_idx[x_dim] = w; - uint32_t dest_multi_idx[N]; - for (uint32_t i = 0; i < N; ++i) { - dest_multi_idx[i] = src_multi_idx[perm[i]]; - } - for (uint32_t i = 0; i < N; i++) { - DPRINT << "src_multi_idx[" << i << "] = " << src_multi_idx[i] << " "; - } - DPRINT << ENDL(); - for (uint32_t i = 0; i < N; i++) { - DPRINT << "dest_multi_idx[" << i << "] = " << dest_multi_idx[i] << " "; - } - DPRINT << ENDL(); + dest_multi_idx[x_dim_in_dest] = w; - // Convert destination multi-dimensional index to linear index - uint32_t dest_linear_idx = 0; - for (uint32_t i = 0; i < N - 1; ++i) { - dest_linear_idx += dest_multi_idx[i] * dest_strides[i]; + // Update dest_linear_idx + uint32_t dest_linear_idx = dest_linear_idx_base; + if (x_dim_in_dest < N - 1) { // Exclude W dimension + dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; } uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0); noc_async_write(transposed_buffer_read_addr + w * output_page_size, dst_noc_addr, output_page_size); - noc_async_write_barrier(); } + noc_async_write_barrier(); cb_pop_front(tt::CBIndex::c_0, X); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 3f59a304582..fe59b03d16f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -4,6 +4,7 @@ #include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" #include "tt_metal/common/work_split.hpp" +#include "noc/noc_parameters.h" // DRAM_ALIGNMENT namespace ttnn::operations::data_movement { From 21d29cabb60a0835c7c68c4193df75d73f648d49 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Thu, 5 Dec 2024 17:26:35 +0000 Subject: [PATCH 06/20] #15165: make N-d permute on width multicore --- .../unit_tests/operations/test_permute.py | 21 ++ ...permute_interleaved_rm_blocked_generic.cpp | 114 +++++++++++ ...permute_interleaved_rm_blocked_generic.cpp | 187 ++++++++++++++++++ ...r_permute_interleaved_rm_width_permute.cpp | 22 ++- .../device/permute_device_operation.cpp | 2 +- .../device/permute_device_operation.hpp | 23 ++- .../device/permute_program_factory.cpp | 185 +++++++++++++++++ 7 files changed, 550 insertions(+), 4 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index aa394c98af5..806338a8c93 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -202,3 +202,24 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): tt_output = ttnn.permute(tt_input, perm) tt_output = ttnn.to_torch(tt_output) assert_with_pcc(torch_output, tt_output, 0.9999) + + +@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65)]) +@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1)]) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16]) +def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): + torch.manual_seed(520) + torch.set_printoptions(threshold=10000, precision=2, linewidth=1000) + input_a = torch.randn(shape) + + torch_output = torch.permute(input_a, perm) + + tt_input = ttnn.from_torch( + input_a, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=dtype, memory_config=memory_config + ) + + tt_output = ttnn.permute(tt_input, perm) + tt_output = ttnn.to_torch(tt_output) + + assert_with_pcc(torch_output, tt_output, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp new file mode 100644 index 00000000000..f6a86b80560 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "debug/dprint.h" + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } + DPRINT << ENDL(); +} + +void kernel_main() { + constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t input_cb_page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + constexpr uint32_t x_dim = get_compile_time_arg_val(4); + constexpr uint32_t num_blocks_total = get_compile_time_arg_val(5); + constexpr uint32_t x_blocks = get_compile_time_arg_val(6); + constexpr uint32_t w_blocks = get_compile_time_arg_val(7); + constexpr uint32_t x_block_size = get_compile_time_arg_val(8); + constexpr uint32_t w_block_size = get_compile_time_arg_val(9); + constexpr uint32_t element_size = get_compile_time_arg_val(10); + constexpr uint32_t input_tensor_page_size = get_compile_time_arg_val(11); + + constexpr uint32_t w_block_size_bytes = w_block_size * element_size; + + const uint32_t src_addr = get_arg_val(0); + + uint32_t start_block = get_arg_val(1); + uint32_t end_block = get_arg_val(2); + + uint32_t input_shape[N], src_strides[N]; + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + src_strides[i - 3] = get_arg_val(i + N); + } + + /** + * num_blocks_total blocks in the tensor which are rows_before X, X blocks, rows after X and W blocks + * collapse rows_before and rows_after into a single rows variable + * rows * X blocks * W blocks = num_blocks_total + */ + + uint32_t X = input_shape[x_dim]; + uint32_t X_stride = src_strides[x_dim]; + + const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = input_tensor_page_size}; + + uint32_t curr_addr = src_addr; + uint32_t idxs[N]; + idxs[N - 1] = 0; + for (uint32_t block = start_block; block < end_block; ++block) { + uint32_t w_block = block % w_blocks; + uint32_t rem = block / w_blocks; + uint32_t x_block = rem % x_blocks; + rem = rem / x_blocks; + uint32_t xw_block = rem % (num_rows / X); + uint32_t remainder = xw_block; + + uint32_t x_start = x_block * x_block_size; + uint32_t x_end = min(x_start + x_block_size, X); + + uint32_t w_start = w_block * w_block_size; + uint32_t w_end = min(w_start + w_block_size, input_shape[N - 1]); + uint32_t w_offset = w_start * element_size; + + uint32_t w_read_size_bytes = (w_end - w_start) * element_size; + + // Map linear index i to multidimensional indices idxs[] + for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + if (d == (int32_t)x_dim) { + idxs[d] = 0; // Initialize x_dim to zero (will be set in inner loop) + continue; // Skip x_dim during mapping + } + idxs[d] = remainder % input_shape[d]; + remainder /= input_shape[d]; + } + idxs[N - 1] = 0; // Initialize W dimension index to zero if not already set + + // Precompute the base address offset (excluding x_dim) + uint64_t base_addr_offset = 0; + for (uint32_t d = 0; d < N; ++d) { + if (d != x_dim) { + base_addr_offset += idxs[d] * src_strides[d]; + } + } + + cb_reserve_back(tt::CBIndex::c_0, x_block_size); + uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + uint32_t page_offset = 0; + // Read along the X dimension + for (uint32_t x = x_start; x < x_end; ++x) { + // Set the index for the X dimension + uint32_t idx_x = x; + // Compute the address offset for this index + uint64_t addr_offset = base_addr_offset + idx_x * X_stride; + uint64_t src_noc_addr = get_noc_addr(addr_offset, s0, w_offset); + noc_async_read(src_noc_addr, src_buffer_l1_addr + page_offset, w_read_size_bytes); + page_offset += input_cb_page_size; + } + noc_async_read_barrier(); + cb_push_back(tt::CBIndex::c_0, x_block_size); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp new file mode 100644 index 00000000000..8c6a9e2e3de --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +// Function template to swap two elements in a uint32_t array +template +FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { + // Perform the swap + uint32_t temp = array[i]; + array[i] = array[j]; + array[j] = temp; +} + +inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { + volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; + for (uint32_t page = 0; page < npages; ++page) { + DPRINT << start + page << ": "; + for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { + DPRINT << BF16(*ptr) << " "; + } + DPRINT << ENDL(); + } + DPRINT << ENDL() << ENDL(); +} + +FORCE_INLINE void transpose_XW_to_WX( + uint32_t input_l1_addr, + uint32_t output_l1_addr, + uint32_t X, + uint32_t W, + uint32_t element_size, + uint32_t input_page_size, + uint32_t output_page_size) { + volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); + volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); + // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner + // each element is element_size bytes + // each row is W elements, and each row is separated by input_page_size bytes + // each output row is X elements, and each row is separated by output_page_size bytes + + for (uint32_t x = 0; x < X; ++x) { + for (uint32_t w = 0; w < W; ++w) { + // Compute the input and output addresses + uint32_t input_addr = x * input_page_size + w * element_size; + uint32_t output_addr = w * output_page_size + x * element_size; + // Copy the element - do we have memcpy? use this for now + for (uint32_t i = 0; i < element_size; ++i) { + output_ptr[output_addr + i] = input_ptr[input_addr + i]; + } + } + } +} + +void kernel_main() { + constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); + constexpr uint32_t N = get_compile_time_arg_val(1); + constexpr uint32_t output_cb_page_size = get_compile_time_arg_val(2); + constexpr uint32_t num_rows = get_compile_time_arg_val(3); + + constexpr uint32_t X = get_compile_time_arg_val(4); + constexpr uint32_t X_stride = get_compile_time_arg_val(5); + constexpr uint32_t x_dim = get_compile_time_arg_val(6); + + constexpr uint32_t W_stride = get_compile_time_arg_val(7); + constexpr uint32_t input_cb_page_size = get_compile_time_arg_val(8); + constexpr uint32_t element_size_bytes = get_compile_time_arg_val(9); + + constexpr uint32_t num_blocks_total = get_compile_time_arg_val(10); + constexpr uint32_t x_blocks = get_compile_time_arg_val(11); + constexpr uint32_t w_blocks = get_compile_time_arg_val(12); + constexpr uint32_t x_block_size = get_compile_time_arg_val(13); + constexpr uint32_t w_block_size = get_compile_time_arg_val(14); + constexpr uint32_t W = get_compile_time_arg_val(15); + constexpr uint32_t output_tensor_page_size = get_compile_time_arg_val(16); + + constexpr uint32_t x_block_size_bytes = x_block_size * element_size_bytes; + constexpr uint32_t w_dim = N - 1; + + const uint32_t dst_addr = get_arg_val(0); + const uint32_t start_block = get_arg_val(1); + const uint32_t end_block = get_arg_val(2); + + const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_tensor_page_size}; + + uint32_t input_shape[N], perm[N], dest_strides[N]; + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + perm[i - 3] = get_arg_val(i + N); + dest_strides[i - 3] = get_arg_val(i + 2 * N); + } + + // Adjust for the transpose between X and W dimensions + swap_elements(input_shape, x_dim, w_dim); + for (uint32_t i = 0; i < N; i++) { + if (perm[i] == x_dim) { + perm[i] = w_dim; + } else if (perm[i] == w_dim) { + perm[i] = x_dim; + } + } + + uint32_t x_dim_in_dest = N; // Invalid index + for (uint32_t i = 0; i < N; ++i) { + if (perm[i] == x_dim) { + x_dim_in_dest = i; + break; + } + } + uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); + uint32_t src_multi_idx[N] = {0}; + uint32_t dest_multi_idx[N] = {0}; + for (uint32_t block = start_block; block < end_block; ++block) { + uint32_t w_block = block % w_blocks; + uint32_t rem = block / w_blocks; + uint32_t x_block = rem % x_blocks; + rem = rem / x_blocks; + uint32_t xw_block = rem % (num_rows / X); + // Map linear index i to multidimensional indices idxs[] + uint32_t remainder = xw_block; + + uint32_t x_start = x_block * x_block_size; + uint32_t x_end = min(x_start + x_block_size, X); + uint32_t x_offset = x_start * element_size_bytes; + + uint32_t w_start = w_block * w_block_size; + uint32_t w_end = min(w_start + w_block_size, W); + + uint32_t x_read_size_bytes = (x_end - x_start) * element_size_bytes; + + // Compute source indices + size_t remaining = xw_block; + for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + if (d == (int32_t)x_dim) { + continue; // Skip x_dim + } + src_multi_idx[d] = remaining % input_shape[d]; + remaining /= input_shape[d]; + } + + // Precompute dest_multi_idx and dest_linear_idx_base + uint32_t dest_linear_idx_base = 0; + for (uint32_t i = 0; i < N; ++i) { + uint32_t src_idx = perm[i]; + if (src_idx != x_dim) { + dest_multi_idx[i] = src_multi_idx[src_idx]; + if (i < N - 1) { // Exclude W dimension + dest_linear_idx_base += dest_multi_idx[i] * dest_strides[i]; + } + } + } + + cb_wait_front(tt::CBIndex::c_0, x_block_size); + uint32_t src_buffer_l1_addr = get_read_ptr(tt::CBIndex::c_0); + print_pages(src_buffer_l1_addr, 32, 32); + // Transpose the block + transpose_XW_to_WX( + src_buffer_l1_addr, + transposed_buffer_read_addr, + x_block_size, + w_block_size, + element_size_bytes, + input_cb_page_size, + output_cb_page_size); + print_pages(transposed_buffer_read_addr, 32, 32); + // Update only the changing components inside the loop + for (uint32_t w = w_start; w < w_end; ++w) { + src_multi_idx[x_dim] = w; + dest_multi_idx[x_dim_in_dest] = w; + + // Update dest_linear_idx + uint32_t dest_linear_idx = dest_linear_idx_base; + if (x_dim_in_dest < N - 1) { // Exclude W dimension + dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; + } + uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0, x_offset); + uint32_t l1_addr = transposed_buffer_read_addr + (w - w_start) * output_cb_page_size; + noc_async_write(l1_addr, dst_noc_addr, x_read_size_bytes); + } + noc_async_write_barrier(); + cb_pop_front(tt::CBIndex::c_0, x_block_size); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp index fdc61bc04c4..552ae483651 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp @@ -68,6 +68,20 @@ void kernel_main() { const uint32_t dst_addr = get_arg_val(0); + DPRINT << "N = " << N << ENDL(); + DPRINT << "page_size = " << output_page_size << ENDL(); + DPRINT << "num_rows = " << num_rows << ENDL(); + DPRINT << "x_dim = " << x_dim << ENDL(); + DPRINT << "X = " << X << ENDL(); + DPRINT << "X_stride = " << X_stride << ENDL(); + DPRINT << "x_dim = " << x_dim << ENDL(); + DPRINT << "W = " << W << ENDL(); + DPRINT << "W_stride = " << W_stride << ENDL(); + DPRINT << "input_page_size = " << input_page_size << ENDL(); + DPRINT << "element_size_bytes = " << element_size_bytes << ENDL(); + DPRINT << "w_dim = " << w_dim << ENDL(); + DPRINT << "dst_addr = " << dst_addr << ENDL(); + const InterleavedAddrGen s0 = { .bank_base_address = dst_addr, .page_size = output_page_size @@ -135,15 +149,19 @@ void kernel_main() { for (uint32_t w = 0; w < W; ++w) { src_multi_idx[x_dim] = w; dest_multi_idx[x_dim_in_dest] = w; - + // for (uint32_t i = 0; i < N; ++i) { + // DPRINT << "dest_multi_idx[" << i << "] = " << dest_multi_idx[i] << " "; + // } + // DPRINT << ENDL(); // Update dest_linear_idx uint32_t dest_linear_idx = dest_linear_idx_base; if (x_dim_in_dest < N - 1) { // Exclude W dimension dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; } - + DPRINT << "dest_linear_idx = " << dest_linear_idx << ENDL(); uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0); noc_async_write(transposed_buffer_read_addr + w * output_page_size, dst_noc_addr, output_page_size); + DPRINT << ENDL(); } noc_async_write_barrier(); cb_pop_front(tt::CBIndex::c_0, X); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index a3345096f54..af6ee177fca 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -15,7 +15,7 @@ PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) { return SingleCore{}; } - return SingleCoreWidthPermute{}; + return MultiCoreBlockedGeneric{}; } void PermuteDeviceOperation::validate_on_program_cache_miss( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 91756172a4f..9454a09a0e2 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -70,7 +70,28 @@ struct PermuteDeviceOperation { tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + struct MultiCoreBlockedGeneric { + // Shared variables are the variables that are shared between the create and override_runtime_arguments methods + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + CoreRangeSet all_cores; + }; + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value); + }; + + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index fe59b03d16f..a11a804118e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -260,4 +260,189 @@ void PermuteDeviceOperation::SingleCoreWidthPermute::override_runtime_arguments( } } +PermuteDeviceOperation::MultiCoreBlockedGeneric::cached_program_t +PermuteDeviceOperation::MultiCoreBlockedGeneric::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + using namespace tt; + using namespace tt::tt_metal; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + + tt::tt_metal::Program program{}; + + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + uint32_t w_block_size = constants::TILE_WIDTH; + uint32_t input_cb_page_size = w_block_size * input_tensor.element_size(); + + tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); + uint32_t x_block_size = constants::TILE_HEIGHT; + uint32_t output_cb_page_size = x_block_size * input_tensor.element_size(); + + tt::tt_metal::Device* device = input_tensor.device(); + + uint32_t src0_cb_index = tt::CBIndex::c_0; + uint32_t src1_cb_index = tt::CBIndex::c_1; + uint32_t num_input_pages_to_read = 2; + + // we are focused on reading one row at a time, in a pattern that allows us to write an entire output row at a time + // if W is being swapped with another dim X (e.g. H), then we need to read X rows at a time (X is the new row + // dimension) CB is thus X pages in size (X*W*element_size) we read in X input rows of size W, and write out W + // output rows of size X find the new row dimension (X) + + uint32_t x_dim = operation_attributes.dims.back(); + uint32_t X = input_tensor.get_logical_shape()[x_dim]; + // stride from one row to the next for each dim in the input tensor + auto input_strides = detail::get_row_strides(input_tensor.get_logical_shape()); + uint32_t X_stride = input_strides[x_dim]; + + auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); + // after we transpose X and W, we need to stride from one row to the next for each dim in the output tensor + uint32_t W = input_tensor.get_logical_shape()[-1]; + uint32_t W_stride = output_strides[x_dim]; + + uint32_t N = operation_attributes.dims.size(); + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + // treat the input tensor as 3D with rows * x_blocks * w_blocks + uint32_t x_blocks = tt::div_up(X, x_block_size); + uint32_t w_blocks = tt::div_up(W, w_block_size); + uint32_t num_blocks_total = (num_rows / X) * x_blocks * w_blocks; + + auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_blocks_total); + + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * input_cb_page_size * x_block_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, input_cb_page_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * output_cb_page_size * w_block_size, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, output_cb_page_size); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = { + (uint32_t)src_is_dram, + N, + input_cb_page_size, + num_rows, + x_dim, + num_blocks_total, + x_blocks, + w_blocks, + x_block_size, + w_block_size, + input_tensor.element_size(), + input_tensor.get_logical_shape()[-1] * input_tensor.element_size()}; + + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "reader_permute_interleaved_rm_blocked_generic.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); + + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = { + (std::uint32_t)dst_is_dram, + N, + output_cb_page_size, + num_rows, + + X, + X_stride, + x_dim, + + W_stride, + input_cb_page_size, + input_tensor.element_size(), + + num_blocks_total, + x_blocks, + w_blocks, + x_block_size, + w_block_size, + + W, + output_tensor.get_logical_shape()[-1] * output_tensor.element_size()}; + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "writer_permute_interleaved_rm_blocked_generic.cpp", + all_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + + auto input_shape_view = input_tensor.get_logical_shape().view(); + + auto cores = corerange_to_cores(all_cores, std::nullopt); + + uint32_t start_block = 0; + uint32_t num_blocks_per_core = 0; + for (const auto& core : cores) { + if (core_group_1.contains(core)) { + num_blocks_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.contains(core)) { + num_blocks_per_core = num_tiles_per_core_group_2; + } else { + // no-op + num_blocks_per_core = 0; + } + uint32_t end_block = start_block + num_blocks_per_core; + std::vector reader_runtime_args = {src_buffer->address(), start_block, end_block}; + reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); + + std::vector writer_runtime_args = {dst_buffer->address(), start_block, end_block}; + + writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + writer_runtime_args.insert( + writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); + writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + start_block = end_block; + } + + return { + std::move(program), + {.unary_reader_kernel_id = unary_reader_kernel_id, + .unary_writer_kernel_id = unary_writer_kernel_id, + .all_cores = all_cores}}; +} + +void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& tensor_return_value) { + auto& program = cached_program.program; + auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + + const auto& input_tensor = tensor_args.input_tensor; + auto& output_tensor = tensor_return_value; + + auto src_buffer = input_tensor.buffer(); + auto dst_buffer = output_tensor.buffer(); + auto& all_cores = cached_program.shared_variables.all_cores; + + auto cores = corerange_to_cores(all_cores, std::nullopt); + for (const auto& core : cores) { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args_writer[0] = dst_buffer->address(); + } +} + } // namespace ttnn::operations::data_movement From ce68210481606aeee3ee2ffc2a466694fe901a59 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Fri, 6 Dec 2024 06:45:26 +0000 Subject: [PATCH 07/20] #0: use unpack and pack for tilize, transpose, untilize --- .../unit_tests/operations/test_permute.py | 2 - .../data_movement/common/kernels/common.hpp | 39 ++++++++++ .../transpose_xh_rm_single_tile_size.cpp | 70 +++++++++++++++++ ...permute_interleaved_rm_blocked_generic.cpp | 13 ---- ...permute_interleaved_rm_blocked_generic.cpp | 77 +++---------------- .../device/permute_device_operation.hpp | 1 + .../device/permute_program_factory.cpp | 53 ++++++++++--- 7 files changed, 162 insertions(+), 93 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 806338a8c93..cac0e3e41ab 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -192,7 +192,6 @@ def generate_permutations(N): def test_permute_5d_width(shape, perm, memory_config, dtype, device): torch.manual_seed(2005) input_a = torch.randn(shape) - # print(input_a) torch_output = torch.permute(input_a, perm) tt_input = ttnn.from_torch( @@ -210,7 +209,6 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): @pytest.mark.parametrize("dtype", [ttnn.bfloat16]) def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): torch.manual_seed(520) - torch.set_printoptions(threshold=10000, precision=2, linewidth=1000) input_a = torch.randn(shape) torch_output = torch.permute(input_a, perm) diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index 27c68f53b18..34cf4e3eb3b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -137,4 +137,43 @@ template FORCE_INLINE constexpr uint32_t round_up() { return b * div_up(); } + +// Function template to swap two elements in a uint32_t array +template +FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { + // Perform the swap + uint32_t temp = array[i]; + array[i] = array[j]; + array[j] = temp; +} + +// 2D Transpose function for debug use in reader/writer kernels +FORCE_INLINE void transpose_2d( + uint32_t input_l1_addr, + uint32_t output_l1_addr, + uint32_t X, + uint32_t W, + uint32_t element_size, + uint32_t input_page_size, + uint32_t output_page_size) { + volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); + volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); + // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner + // each element is element_size bytes + // each row is W elements, and each row is separated by input_page_size bytes + // each output row is X elements, and each row is separated by output_page_size bytes + + for (uint32_t x = 0; x < X; ++x) { + for (uint32_t w = 0; w < W; ++w) { + // Compute the input and output addresses + uint32_t input_addr = x * input_page_size + w * element_size; + uint32_t output_addr = w * output_page_size + x * element_size; + // Copy the element - do we have memcpy? use this for now + for (uint32_t i = 0; i < element_size; ++i) { + output_ptr[output_addr + i] = input_ptr[input_addr + i]; + } + } + } +} + } // namespace tt::data_movement::common diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp new file mode 100644 index 00000000000..0b69361f57a --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" +#include "compute_kernel_api/transpose_wh.h" +#include "compute_kernel_api/tilize.h" +#include "compute_kernel_api/untilize.h" +#include "compute_kernel_api/pack_untilize.h" + +namespace NAMESPACE { +void MAIN { + constexpr uint32_t x_block_size = get_compile_time_arg_val(0); + constexpr uint32_t w_block_size = get_compile_time_arg_val(1); + + uint32_t num_blocks = get_arg_val(0); + + constexpr auto cb_in = tt::CBIndex::c_0; + constexpr auto cb_tilize = tt::CBIndex::c_1; + constexpr auto cb_out = tt::CBIndex::c_2; + + unary_op_init_common(cb_in, cb_out); + + for (uint32_t n = 0; n < num_blocks; n++) { + // have to global init here, otherwise pcc is bad + // if n > 0, then some register isn't cleared and the output of tilize_block is garbage + unary_op_init_common(cb_in, cb_out); + // tilize input via unpack and then pack + tilize_init_short(cb_in, 1); + + cb_wait_front(cb_in, x_block_size); + // results are correct according to unpacker here + cb_reserve_back(cb_tilize, 1); + + // removing this line causes the output of tilize_block to be garbage in the second iteration + tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize + + // tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit? + cb_push_back(cb_tilize, 1); + cb_pop_front(cb_in, x_block_size); + + tilize_uninit(cb_in); + + // transpose input + cb_wait_front(cb_tilize, 1); + transpose_wh_init_short(cb_tilize); + pack_untilize_dst_init_short<1>(cb_out); + + tile_regs_acquire(); + transpose_wh_tile(cb_tilize, 0, 0); // transpose call + tile_regs_commit(); + + // pack and untilize + cb_reserve_back(cb_out, w_block_size); + + tile_regs_wait(); + pack_untilize_dst<1>(cb_out); // pack call + tile_regs_release(); + + cb_push_back(cb_out, w_block_size); + + cb_wait_front(cb_out, w_block_size); + pack_untilize_uninit(); + + cb_pop_front(cb_tilize, 1); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp index f6a86b80560..309aafe0562 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp @@ -4,19 +4,6 @@ #include #include "dataflow_api.h" -#include "debug/dprint.h" - -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } - DPRINT << ENDL(); -} void kernel_main() { constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(0); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp index 8c6a9e2e3de..4f214ba303f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp @@ -7,55 +7,6 @@ #include "debug/dprint.h" #include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" -// Function template to swap two elements in a uint32_t array -template -FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { - // Perform the swap - uint32_t temp = array[i]; - array[i] = array[j]; - array[j] = temp; -} - -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } - DPRINT << ENDL() << ENDL(); -} - -FORCE_INLINE void transpose_XW_to_WX( - uint32_t input_l1_addr, - uint32_t output_l1_addr, - uint32_t X, - uint32_t W, - uint32_t element_size, - uint32_t input_page_size, - uint32_t output_page_size) { - volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); - volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); - // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner - // each element is element_size bytes - // each row is W elements, and each row is separated by input_page_size bytes - // each output row is X elements, and each row is separated by output_page_size bytes - - for (uint32_t x = 0; x < X; ++x) { - for (uint32_t w = 0; w < W; ++w) { - // Compute the input and output addresses - uint32_t input_addr = x * input_page_size + w * element_size; - uint32_t output_addr = w * output_page_size + x * element_size; - // Copy the element - do we have memcpy? use this for now - for (uint32_t i = 0; i < element_size; ++i) { - output_ptr[output_addr + i] = input_ptr[input_addr + i]; - } - } - } -} - void kernel_main() { constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); constexpr uint32_t N = get_compile_time_arg_val(1); @@ -77,6 +28,7 @@ void kernel_main() { constexpr uint32_t w_block_size = get_compile_time_arg_val(14); constexpr uint32_t W = get_compile_time_arg_val(15); constexpr uint32_t output_tensor_page_size = get_compile_time_arg_val(16); + constexpr uint32_t cb_id_in = tt::CBIndex::c_2; constexpr uint32_t x_block_size_bytes = x_block_size * element_size_bytes; constexpr uint32_t w_dim = N - 1; @@ -95,7 +47,7 @@ void kernel_main() { } // Adjust for the transpose between X and W dimensions - swap_elements(input_shape, x_dim, w_dim); + tt::data_movement::common::swap_elements(input_shape, x_dim, w_dim); for (uint32_t i = 0; i < N; i++) { if (perm[i] == x_dim) { perm[i] = w_dim; @@ -111,16 +63,18 @@ void kernel_main() { break; } } - uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); + uint32_t src_multi_idx[N] = {0}; uint32_t dest_multi_idx[N] = {0}; for (uint32_t block = start_block; block < end_block; ++block) { + // Compute block indices uint32_t w_block = block % w_blocks; uint32_t rem = block / w_blocks; uint32_t x_block = rem % x_blocks; rem = rem / x_blocks; uint32_t xw_block = rem % (num_rows / X); - // Map linear index i to multidimensional indices idxs[] + + // Map linear index xw_block to multidimensional indices idxs[] uint32_t remainder = xw_block; uint32_t x_start = x_block * x_block_size; @@ -154,20 +108,9 @@ void kernel_main() { } } - cb_wait_front(tt::CBIndex::c_0, x_block_size); - uint32_t src_buffer_l1_addr = get_read_ptr(tt::CBIndex::c_0); - print_pages(src_buffer_l1_addr, 32, 32); - // Transpose the block - transpose_XW_to_WX( - src_buffer_l1_addr, - transposed_buffer_read_addr, - x_block_size, - w_block_size, - element_size_bytes, - input_cb_page_size, - output_cb_page_size); - print_pages(transposed_buffer_read_addr, 32, 32); - // Update only the changing components inside the loop + // Wait for transposed block + cb_wait_front(cb_id_in, w_block_size); + uint32_t transposed_buffer_read_addr = get_read_ptr(cb_id_in); for (uint32_t w = w_start; w < w_end; ++w) { src_multi_idx[x_dim] = w; dest_multi_idx[x_dim_in_dest] = w; @@ -182,6 +125,6 @@ void kernel_main() { noc_async_write(l1_addr, dst_noc_addr, x_read_size_bytes); } noc_async_write_barrier(); - cb_pop_front(tt::CBIndex::c_0, x_block_size); + cb_pop_front(cb_id_in, w_block_size); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 9454a09a0e2..a4377fcf521 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -75,6 +75,7 @@ struct PermuteDeviceOperation { struct shared_variables_t { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; + KernelHandle compute_kernel_id; CoreRangeSet all_cores; }; using cached_program_t = ttnn::device_operation::CachedProgram; diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index a11a804118e..9e0765e498a 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -287,7 +287,8 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( tt::tt_metal::Device* device = input_tensor.device(); uint32_t src0_cb_index = tt::CBIndex::c_0; - uint32_t src1_cb_index = tt::CBIndex::c_1; + uint32_t src1_cb_index = tt::CBIndex::c_2; + uint32_t src2_cb_index = tt::CBIndex::c_1; uint32_t num_input_pages_to_read = 2; // we are focused on reading one row at a time, in a pattern that allows us to write an entire output row at a time @@ -330,6 +331,13 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( .set_page_size(src1_cb_index, output_cb_page_size); auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + tt::tt_metal::CircularBufferConfig cb_src2_config = + tt::tt_metal::CircularBufferConfig( + num_input_pages_to_read * x_block_size * w_block_size * input_tensor.element_size(), + {{src2_cb_index, cb_data_format}}) + .set_page_size(src2_cb_index, x_block_size * w_block_size * input_tensor.element_size()); + auto cb_src2 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src2_config); + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = { (uint32_t)src_is_dram, @@ -382,10 +390,33 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( all_cores, tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); + std::vector compute_kernel_args = {x_block_size, w_block_size}; + bool fp32_dest_acc_en = cb_data_format_output == tt::DataFormat::Float32; + auto compute_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp", + all_cores, + tt::tt_metal::ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .compile_args = compute_kernel_args, + }); + auto input_shape_view = input_tensor.get_logical_shape().view(); + std::vector reader_runtime_args = {src_buffer->address(), 0, 0}; + reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); + + std::vector writer_runtime_args = {dst_buffer->address(), 0, 0}; + + writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); + writer_runtime_args.insert( + writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); + writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); auto cores = corerange_to_cores(all_cores, std::nullopt); + std::vector compute_runtime_args = {dst_buffer->address(), 0, 0}; + uint32_t start_block = 0; uint32_t num_blocks_per_core = 0; for (const auto& core : cores) { @@ -397,19 +428,15 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( // no-op num_blocks_per_core = 0; } + compute_runtime_args[0] = num_blocks_per_core; uint32_t end_block = start_block + num_blocks_per_core; - std::vector reader_runtime_args = {src_buffer->address(), start_block, end_block}; - reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); - reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); - - std::vector writer_runtime_args = {dst_buffer->address(), start_block, end_block}; - - writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); - writer_runtime_args.insert( - writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); - writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); + reader_runtime_args[1] = start_block; + reader_runtime_args[2] = end_block; + writer_runtime_args[1] = start_block; + writer_runtime_args[2] = end_block; tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, compute_kernel_id, core, compute_runtime_args); start_block = end_block; } @@ -417,6 +444,7 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( std::move(program), {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id, + .compute_kernel_id = compute_kernel_id, .all_cores = all_cores}}; } @@ -428,6 +456,7 @@ void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments auto& program = cached_program.program; auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto& compute_kernel_id = cached_program.shared_variables.compute_kernel_id; const auto& input_tensor = tensor_args.input_tensor; auto& output_tensor = tensor_return_value; @@ -442,6 +471,8 @@ void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments runtime_args[0] = src_buffer->address(); auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); runtime_args_writer[0] = dst_buffer->address(); + auto& runtime_args_compute = tt::tt_metal::GetRuntimeArgs(program, compute_kernel_id, core); + runtime_args_compute[0] = dst_buffer->address(); } } From 7f40b41e605a03c499e880e217c1191f7b343686 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Fri, 6 Dec 2024 20:51:00 +0000 Subject: [PATCH 08/20] #15589: add some extra tests and cleanup permute invoke #15750: remove composite flag --- .../unit_tests/operations/test_permute.py | 20 ++++- .../data_movement/permute/permute.cpp | 81 +++++++------------ .../data_movement/permute/permute.hpp | 1 - .../data_movement/permute/permute_pybind.cpp | 2 +- ttnn/cpp/ttnn/tensor/shape/shape.cpp | 12 +++ ttnn/cpp/ttnn/tensor/shape/shape.hpp | 2 + 6 files changed, 63 insertions(+), 55 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index cac0e3e41ab..068b067c1ac 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -203,8 +203,8 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): assert_with_pcc(torch_output, tt_output, 0.9999) -@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65)]) -@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1)]) +@pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)]) +@pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG]) @pytest.mark.parametrize("dtype", [ttnn.bfloat16]) def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): @@ -221,3 +221,19 @@ def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): tt_output = ttnn.to_torch(tt_output) assert_with_pcc(torch_output, tt_output, 0.9999) + + +def test_permute_nd(device): + torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + output_tensor = ttnn.permute(input_tensor, (0, 2, 4, 3, 5, 1)) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, (0, 2, 4, 3, 5, 1)) + assert_with_pcc(torch_output, output_tensor, 0.9999) + + +def test_permute_squeeze(device): + tensor = ttnn.ones((1, 1, 3)) + tensor = ttnn.to_device(tensor, device) + out = ttnn.permute(tensor, (0, 1, 2)) + assert out.shape == (1, 1, 3) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 288f5b5a101..a33e426d5e3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -25,18 +25,9 @@ inline bool is_on_device(const Tensor& t) { ttnn::has_storage_type_of(t, ttnn::StorageType::MULTI_DEVICE); } -inline bool has_tile_padding(const Tensor& t) { - if (t.get_logical_shape().rank() > 1) { - auto the_shape = t.get_logical_shape(); - auto the_shape_with_padding = t.get_padded_shape(); - return the_shape[-1] != the_shape_with_padding[-1] or the_shape[-2] != the_shape_with_padding[-2]; - } - return false; -} - ttnn::Tensor permute_impl( const ttnn::Tensor& a, - const SmallVector& dims, + const tt::stl::Span& dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { using ttnn::operations::experimental::auto_format::AutoFormat; @@ -57,7 +48,8 @@ ttnn::Tensor permute_impl( TT_FATAL( !(pad_value.has_value() && pad_value.value() != 0.0f), "Non-zero padding is not supported for permute on tensors with rank > 4."); - input = ttnn::prim::permute(input, dims, output_mem_config, std::nullopt); + SmallVector permute_dims(dims.begin(), dims.end()); + input = ttnn::prim::permute(input, permute_dims, output_mem_config, std::nullopt); return ttnn::to_layout(input, a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr); } @@ -148,7 +140,7 @@ ttnn::Tensor permute_impl( ttnn::Tensor permute_launch( const ttnn::Tensor& a, - tt::stl::Span dims, + tt::stl::Span dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { std::vector output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))}; @@ -159,31 +151,21 @@ ttnn::Tensor permute_launch( const std::vector>& optional_output_tensors) mutable -> std::vector { auto& a = input_tensors.at(0); - SmallVector normalized_dims(dims.size()); - std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [a](std::int64_t idx) { - return a.get_legacy_shape().get_normalized_index(idx); - }); - SmallVector seq_dims(dims.size()); - std::iota(seq_dims.begin(), seq_dims.end(), 0); - if (normalized_dims == seq_dims) { - return {ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config( - a, output_mem_config)}; - } - return {permute_impl(a, normalized_dims, output_mem_config, pad_value)}; + return {permute_impl(a, dims, output_mem_config, pad_value)}; }, {a}, output_tensors); return output_tensors.at(0); } -Tensor composite_invoke( - const ttnn::Tensor& input_tensor, - tt::stl::Span dims, - const std::optional& memory_config, - const std::optional& pad_value) { - auto output_tensor = - permute_launch(input_tensor, dims, memory_config.value_or(input_tensor.memory_config()), pad_value); - return output_tensor; +bool is_permute_nop(const ttnn::Tensor& a, tt::stl::Span dims) { + if (a.get_shape().rank() == 1) { + return true; + } + auto normalized_dims = ttnn::SmallVector(dims.begin(), dims.end()); + ttnn::SmallVector seq_dims(dims.size()); + std::iota(seq_dims.begin(), seq_dims.end(), 0); + return normalized_dims == seq_dims; } } // namespace detail @@ -193,23 +175,25 @@ ttnn::Tensor ExecutePermute::invoke( const ttnn::Tensor& input_tensor, tt::stl::Span dims, const std::optional& memory_config, - bool composite, const std::optional& pad_value) { - if (composite) { - return detail::composite_invoke(input_tensor, dims, memory_config, pad_value); - } - - const bool initial_input_tensor_on_device = detail::is_on_device(input_tensor); - const auto input_layout = input_tensor.get_layout(); const auto input_rank = input_tensor.get_logical_shape().rank(); - TT_FATAL( input_rank == dims.size(), "The number of dimensions in the tensor input does not match the length of the desired ordering"); + TT_FATAL(detail::is_on_device(input_tensor), "Tensor must already be on device"); + + SmallVector normalized_dims(dims.size()); + std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [input_tensor](std::int64_t idx) { + return input_tensor.get_logical_shape().get_normalized_index(idx); + }); + if (detail::is_permute_nop(input_tensor, normalized_dims)) { + return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config())); + } - auto adjust_order = [](tt::stl::Span dims) { - ttnn::SmallVector new_order; - TT_FATAL(dims.size() <= 4, "Error"); + const auto input_layout = input_tensor.get_layout(); + auto adjust_order = [](tt::stl::Span dims) { + ttnn::SmallVector new_order; + TT_FATAL(dims.size() <= 4, "Minimum rank of tensor required is 4"); int additional_ranks = 4 - dims.size(); for (int i = 0; i < additional_ranks; i++) { new_order.push_back(i); @@ -220,10 +204,10 @@ ttnn::Tensor ExecutePermute::invoke( return new_order; }; auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; - auto iorder = - dims.size() < 4 ? adjust_order(dims) : dims; // internals of permute_impl already adjust negative indices + auto iorder = normalized_dims.size() < 4 + ? adjust_order(normalized_dims) + : normalized_dims; // internals of permute_impl already adjust negative indices - TT_FATAL(detail::is_on_device(itensor), "Error"); auto output_tensor = detail::permute_launch(itensor, iorder, memory_config.value_or(input_tensor.memory_config()), pad_value); output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr); @@ -244,11 +228,6 @@ ttnn::Tensor ExecutePermute::invoke( output_tensor = ttnn::reshape(output_tensor, ttnn::Shape(shape_vec, full_shape_vec)); } - if (initial_input_tensor_on_device and not detail::is_on_device(output_tensor)) { - output_tensor = - ttnn::to_device(output_tensor, input_tensor.device(), memory_config.value_or(input_tensor.memory_config())); - } - return output_tensor; } @@ -257,7 +236,7 @@ ttnn::Tensor ExecutePermute::invoke( tt::stl::Span dims, const std::optional& memory_config, const std::optional& pad_value) { - return invoke(DefaultQueueId, input_tensor, dims, memory_config, true, pad_value); + return invoke(DefaultQueueId, input_tensor, dims, memory_config, pad_value); } ttnn::Tensor ExecutePermute::invoke( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp index 7f9301b696b..2f13b9c2845 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.hpp @@ -15,7 +15,6 @@ struct ExecutePermute { const ttnn::Tensor& input_tensor, tt::stl::Span dims, const std::optional& memory_config, - bool composite = true, const std::optional& pad_value = 0.0f); static ttnn::Tensor invoke( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp index 2fbb5c0bcd0..be6adbf880b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute_pybind.cpp @@ -44,7 +44,7 @@ void bind_permute(py::module& module) { const std::optional& memory_config, uint8_t queue_id, const std::optional& pad_value) { - return self(queue_id, input_tensor, dims, memory_config, false, pad_value); + return self(queue_id, input_tensor, dims, memory_config, pad_value); }, py::arg("input_tensor").noconvert(), py::arg("dims"), diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.cpp b/ttnn/cpp/ttnn/tensor/shape/shape.cpp index 7dee5428526..d4a54500c46 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.cpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.cpp @@ -7,6 +7,7 @@ #include #include #include "ttnn/tensor/shape/small_vector.hpp" +#include "tt_metal/common/assert.hpp" namespace tt::tt_metal { @@ -20,6 +21,17 @@ uint64_t SimpleShape::volume() const { return std::accumulate(cbegin(), cend(), uint64_t{1}, std::multiplies()); } +const uint32_t SimpleShape::get_normalized_index(std::int64_t index) const { + std::int64_t rank = static_cast(this->rank()); + std::uint64_t normalized_index = index >= 0 ? index : rank + index; + TT_FATAL( + normalized_index >= 0 and normalized_index < rank, + "Index is out of bounds for the rank, should be between 0 and {} however is {}", + rank - 1, + normalized_index); + return normalized_index; +} + std::ostream& operator<<(std::ostream& os, const tt::tt_metal::SimpleShape& shape) { os << "SimpleShape(["; for (size_t i = 0; i < shape.rank(); ++i) { diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.hpp b/ttnn/cpp/ttnn/tensor/shape/shape.hpp index b1661578927..f6f78d35fd5 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.hpp @@ -30,6 +30,8 @@ class SimpleShape final : protected ShapeBase { [[nodiscard]] size_t rank() const; [[nodiscard]] uint64_t volume() const; + const uint32_t get_normalized_index(std::int64_t index) const; + // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); auto attribute_values() const { return std::forward_as_tuple(this->value_); } From d547a86e53a3b1363c86cc3420212a0f3703d035 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Fri, 6 Dec 2024 21:17:21 +0000 Subject: [PATCH 09/20] #0: cleab up debug code and other implementations --- .../unit_tests/operations/test_permute.py | 25 ++- ...permute_interleaved_rm_blocked_generic.cpp | 48 +++-- ...r_permute_interleaved_rm_width_permute.cpp | 82 --------- ...permute_interleaved_rm_blocked_generic.cpp | 98 ++++++---- ...r_permute_interleaved_rm_width_permute.cpp | 169 ------------------ .../device/permute_device_operation.hpp | 22 +-- .../device/permute_program_factory.cpp | 127 ------------- .../data_movement/permute/permute.cpp | 20 +-- 8 files changed, 127 insertions(+), 464 deletions(-) delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp delete mode 100644 ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 068b067c1ac..4d9c8197b5a 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -205,8 +205,8 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): @pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)]) @pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)]) -@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG]) -@pytest.mark.parametrize("dtype", [ttnn.bfloat16]) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): torch.manual_seed(520) input_a = torch.randn(shape) @@ -233,7 +233,22 @@ def test_permute_nd(device): def test_permute_squeeze(device): - tensor = ttnn.ones((1, 1, 3)) - tensor = ttnn.to_device(tensor, device) + ones = ttnn.ones((1, 1, 3)) + tensor = ttnn.to_device(ones, device) out = ttnn.permute(tensor, (0, 1, 2)) - assert out.shape == (1, 1, 3) + assert_with_pcc(ttnn.to_torch(out), ttnn.to_torch(ones), 0.9999) + + +@pytest.mark.parametrize("shape", [(1, 49, 768)]) +@pytest.mark.parametrize("perm", generate_permutations(3)) +@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) +@pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) +def test_permute_3D(shape, perm, layout, memory_config, dtype, device): + torch_tensor = torch.rand(shape, dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=dtype, memory_config=memory_config) + output_tensor = ttnn.permute(input_tensor, perm) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, perm) + assert torch_output.shape == output_tensor.shape + assert_with_pcc(torch_output, output_tensor, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp index 309aafe0562..15b8b199c3f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp @@ -19,13 +19,14 @@ void kernel_main() { constexpr uint32_t element_size = get_compile_time_arg_val(10); constexpr uint32_t input_tensor_page_size = get_compile_time_arg_val(11); + // Precomputed constants: size of a 32 element block along the W dimension (measured in bytes) constexpr uint32_t w_block_size_bytes = w_block_size * element_size; const uint32_t src_addr = get_arg_val(0); - uint32_t start_block = get_arg_val(1); uint32_t end_block = get_arg_val(2); + // Input shape and strides (excluding W dimension and measured in rows, not bytes) uint32_t input_shape[N], src_strides[N]; for (uint32_t i = 3; i < N + 3; i++) { input_shape[i - 3] = get_arg_val(i); @@ -33,30 +34,43 @@ void kernel_main() { } /** - * num_blocks_total blocks in the tensor which are rows_before X, X blocks, rows after X and W blocks - * collapse rows_before and rows_after into a single rows variable - * rows * X blocks * W blocks = num_blocks_total + * We have a multidimensional tensor: + * - num_blocks_total = (rows * x_blocks * w_blocks) where rows = num_rows / X + * Here, 'rows' represent the combination of all rows before and after X dimension. + * So: rows * X * W_dimension = total number of elements (conceptually). + * + * For each 'block': + * - Compute which w_block and x_block this corresponds to. + * - Then compute which row set (xw_block) we are in. */ + // x_dim is the dimension along which we are reading the tensor, as it's the new W dimension in the output tensor uint32_t X = input_shape[x_dim]; uint32_t X_stride = src_strides[x_dim]; const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = input_tensor_page_size}; - uint32_t curr_addr = src_addr; uint32_t idxs[N]; idxs[N - 1] = 0; + uint32_t non_x_rows = num_rows / X; + for (uint32_t block = start_block; block < end_block; ++block) { - uint32_t w_block = block % w_blocks; - uint32_t rem = block / w_blocks; - uint32_t x_block = rem % x_blocks; - rem = rem / x_blocks; - uint32_t xw_block = rem % (num_rows / X); + // Decompose block into w_block, x_block, and xw_block indices + uint32_t rem = block; + const uint32_t w_block = rem % w_blocks; // Which W block are we in? + rem /= w_blocks; + + const uint32_t x_block = rem % x_blocks; // Which X block? + rem /= x_blocks; + + uint32_t xw_block = rem % (non_x_rows); // Which row set (beyond X dimension)? uint32_t remainder = xw_block; + // Compute X block boundaries uint32_t x_start = x_block * x_block_size; uint32_t x_end = min(x_start + x_block_size, X); + // Compute W block boundaries uint32_t w_start = w_block * w_block_size; uint32_t w_end = min(w_start + w_block_size, input_shape[N - 1]); uint32_t w_offset = w_start * element_size; @@ -64,6 +78,7 @@ void kernel_main() { uint32_t w_read_size_bytes = (w_end - w_start) * element_size; // Map linear index i to multidimensional indices idxs[] + // We skip x_dim when doing this mapping and set it separately later for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension if (d == (int32_t)x_dim) { idxs[d] = 0; // Initialize x_dim to zero (will be set in inner loop) @@ -82,20 +97,27 @@ void kernel_main() { } } + // Reserve space in the circular buffer for the X-block length cb_reserve_back(tt::CBIndex::c_0, x_block_size); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); + + // We read in 'x_block_len' chunks along the X dimension uint32_t page_offset = 0; // Read along the X dimension for (uint32_t x = x_start; x < x_end; ++x) { - // Set the index for the X dimension - uint32_t idx_x = x; // Compute the address offset for this index - uint64_t addr_offset = base_addr_offset + idx_x * X_stride; + uint64_t addr_offset = base_addr_offset + x * X_stride; uint64_t src_noc_addr = get_noc_addr(addr_offset, s0, w_offset); + + // Perform async read of the current line (w_block_len elements) into L1 noc_async_read(src_noc_addr, src_buffer_l1_addr + page_offset, w_read_size_bytes); + + // Advance output pointer by one page size for next row page_offset += input_cb_page_size; } + // Wait for all async reads to complete before proceeding noc_async_read_barrier(); + // Push the filled block into the circular buffer cb_push_back(tt::CBIndex::c_0, x_block_size); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp deleted file mode 100644 index be32533c2d3..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -#include "debug/dprint.h" - -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } - DPRINT << ENDL(); -} - -void kernel_main() { - constexpr bool src0_is_dram = (bool)get_compile_time_arg_val(0); - constexpr uint32_t N = get_compile_time_arg_val(1); - constexpr uint32_t page_size = get_compile_time_arg_val(2); - constexpr uint32_t num_rows = get_compile_time_arg_val(3); - constexpr uint32_t x_dim = get_compile_time_arg_val(4); - - const uint32_t src_addr = get_arg_val(0); - const DataFormat data_format = get_dataformat(tt::CBIndex::c_0); - - uint32_t input_shape[N], src_strides[N]; - for (uint32_t i = 1; i <= N; i++) { - input_shape[i - 1] = get_arg_val(i); - src_strides[i - 1] = get_arg_val(i + N); - } - - uint32_t X = input_shape[x_dim]; - uint32_t X_stride = src_strides[x_dim]; - - const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; - - uint32_t curr_addr = src_addr; - uint32_t idxs[N]; - idxs[N - 1] = 0; - for (uint32_t i = 0; i < num_rows/X; ++i) { - // Map linear index i to multidimensional indices idxs[] - uint32_t remainder = i; - for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension - if (d == (int32_t)x_dim) { - idxs[d] = 0; // Initialize x_dim to zero (will be set in inner loop) - continue; // Skip x_dim during mapping - } - idxs[d] = remainder % input_shape[d]; - remainder /= input_shape[d]; - } - idxs[N - 1] = 0; // Initialize W dimension index to zero if not already set - - // Precompute the base address offset (excluding x_dim) - uint64_t base_addr_offset = 0; - for (uint32_t d = 0; d < N; ++d) { - if (d != x_dim) { - base_addr_offset += idxs[d] * src_strides[d]; - } - } - - cb_reserve_back(tt::CBIndex::c_0, X); - uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); - - // Read along the X dimension - for (uint32_t j = 0; j < X; ++j) { - // Set the index for the X dimension - uint32_t idx_x = j; - // Compute the address offset for this index - uint64_t addr_offset = base_addr_offset + idx_x * X_stride; - uint64_t src_noc_addr = get_noc_addr(addr_offset, s0); - noc_async_read(src_noc_addr, src_buffer_l1_addr, page_size); - src_buffer_l1_addr += page_size; - } - noc_async_read_barrier(); - cb_push_back(tt::CBIndex::c_0, X); - } -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp index 4f214ba303f..910b23b72d5 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp @@ -4,10 +4,10 @@ #include #include "dataflow_api.h" -#include "debug/dprint.h" #include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" void kernel_main() { + // Compile-time constants constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); constexpr uint32_t N = get_compile_time_arg_val(1); constexpr uint32_t output_cb_page_size = get_compile_time_arg_val(2); @@ -19,7 +19,7 @@ void kernel_main() { constexpr uint32_t W_stride = get_compile_time_arg_val(7); constexpr uint32_t input_cb_page_size = get_compile_time_arg_val(8); - constexpr uint32_t element_size_bytes = get_compile_time_arg_val(9); + constexpr uint32_t element_size = get_compile_time_arg_val(9); constexpr uint32_t num_blocks_total = get_compile_time_arg_val(10); constexpr uint32_t x_blocks = get_compile_time_arg_val(11); @@ -28,17 +28,27 @@ void kernel_main() { constexpr uint32_t w_block_size = get_compile_time_arg_val(14); constexpr uint32_t W = get_compile_time_arg_val(15); constexpr uint32_t output_tensor_page_size = get_compile_time_arg_val(16); + constexpr uint32_t cb_id_in = tt::CBIndex::c_2; - constexpr uint32_t x_block_size_bytes = x_block_size * element_size_bytes; + // Precompute bytes-per-block along X + constexpr uint32_t x_block_size_bytes = x_block_size * element_size; + + // W dimension is always the last dimension constexpr uint32_t w_dim = N - 1; + // Calculate how many "non_x_rows" we have (these are the combinations of all dimensions except X) + constexpr uint32_t non_x_rows = num_rows / X; + + // Destination base address const uint32_t dst_addr = get_arg_val(0); const uint32_t start_block = get_arg_val(1); const uint32_t end_block = get_arg_val(2); + // Interleaved address configuration for the destination const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_tensor_page_size}; + // Input shape, permutation, and destination strides uint32_t input_shape[N], perm[N], dest_strides[N]; for (uint32_t i = 3; i < N + 3; i++) { input_shape[i - 3] = get_arg_val(i); @@ -46,7 +56,8 @@ void kernel_main() { dest_strides[i - 3] = get_arg_val(i + 2 * N); } - // Adjust for the transpose between X and W dimensions + // The source data was transposed between W and X by the previous kernel. + // Adjust input_shape and perm to reflect that swap. tt::data_movement::common::swap_elements(input_shape, x_dim, w_dim); for (uint32_t i = 0; i < N; i++) { if (perm[i] == x_dim) { @@ -56,7 +67,8 @@ void kernel_main() { } } - uint32_t x_dim_in_dest = N; // Invalid index + // Find where the original X dimension ended up in the permuted output + uint32_t x_dim_in_dest = N; // Will hold the position of x_dim in the permuted array for (uint32_t i = 0; i < N; ++i) { if (perm[i] == x_dim) { x_dim_in_dest = i; @@ -66,65 +78,89 @@ void kernel_main() { uint32_t src_multi_idx[N] = {0}; uint32_t dest_multi_idx[N] = {0}; + + // Process each block of data from start_block to end_block for (uint32_t block = start_block; block < end_block; ++block) { - // Compute block indices - uint32_t w_block = block % w_blocks; - uint32_t rem = block / w_blocks; - uint32_t x_block = rem % x_blocks; - rem = rem / x_blocks; - uint32_t xw_block = rem % (num_rows / X); - - // Map linear index xw_block to multidimensional indices idxs[] - uint32_t remainder = xw_block; + // Decompose linear block index into w_block, x_block, and xw_block + uint32_t rem = block; + + // w_block: portion of the W dimension handled by this block + const uint32_t w_block = rem % w_blocks; + rem /= w_blocks; + + // x_block: portion of the X dimension handled by this block + const uint32_t x_block = rem % x_blocks; + rem /= x_blocks; + + // xw_block: which "non-X row set" we are in + const uint32_t xw_block = rem % non_x_rows; - uint32_t x_start = x_block * x_block_size; - uint32_t x_end = min(x_start + x_block_size, X); - uint32_t x_offset = x_start * element_size_bytes; + // Compute start/end boundaries for the current X and W blocks + const uint32_t x_start = x_block * x_block_size; + const uint32_t x_end = min(x_start + x_block_size, X); - uint32_t w_start = w_block * w_block_size; - uint32_t w_end = min(w_start + w_block_size, W); + const uint32_t w_start = w_block * w_block_size; + const uint32_t w_end = min(w_start + w_block_size, W); - uint32_t x_read_size_bytes = (x_end - x_start) * element_size_bytes; + // Compute the read size for the X dimension + const uint32_t x_read_size_bytes = (x_end - x_start) * element_size; + const uint32_t x_offset = x_start * element_size; - // Compute source indices - size_t remaining = xw_block; - for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension + // Decode xw_block into multi-dimensional indices excluding the W dimension and X dimension + uint32_t remainder = xw_block; + for (int32_t d = N - 2; d >= 0; --d) { if (d == (int32_t)x_dim) { - continue; // Skip x_dim + // Skip the original X dimension index during this mapping + continue; } - src_multi_idx[d] = remaining % input_shape[d]; - remaining /= input_shape[d]; + src_multi_idx[d] = remainder % input_shape[d]; + remainder /= input_shape[d]; } - // Precompute dest_multi_idx and dest_linear_idx_base + // Compute dest_multi_idx (excluding W dimension), and a base linear index + // for all dimensions except W and X. We'll add W and X offsets later. uint32_t dest_linear_idx_base = 0; for (uint32_t i = 0; i < N; ++i) { uint32_t src_idx = perm[i]; if (src_idx != x_dim) { dest_multi_idx[i] = src_multi_idx[src_idx]; - if (i < N - 1) { // Exclude W dimension + // Accumulate partial index product for all dimensions except W + if (i < w_dim) { dest_linear_idx_base += dest_multi_idx[i] * dest_strides[i]; } } } - // Wait for transposed block + // Wait for the transposed block data to be ready in the input CB cb_wait_front(cb_id_in, w_block_size); uint32_t transposed_buffer_read_addr = get_read_ptr(cb_id_in); + + // Iterate over the W dimension elements for (uint32_t w = w_start; w < w_end; ++w) { + // Update indices for the current W src_multi_idx[x_dim] = w; dest_multi_idx[x_dim_in_dest] = w; - // Update dest_linear_idx + // Compute final linear index for the current W uint32_t dest_linear_idx = dest_linear_idx_base; - if (x_dim_in_dest < N - 1) { // Exclude W dimension + if (x_dim_in_dest < w_dim) { dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; } + + // Compute the NoC address for the output uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0, x_offset); + + // Compute the L1 address from which to write (offset by W-block pages) uint32_t l1_addr = transposed_buffer_read_addr + (w - w_start) * output_cb_page_size; + + // Perform an asynchronous write of the X-block to the destination noc_async_write(l1_addr, dst_noc_addr, x_read_size_bytes); } + + // Wait until all writes are completed before proceeding to the next block noc_async_write_barrier(); + + // Pop the block from the input circular buffer, as we're done writing it cb_pop_front(cb_id_in, w_block_size); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp deleted file mode 100644 index 552ae483651..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp +++ /dev/null @@ -1,169 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" - -// Function template to swap two elements in a uint32_t array -template -FORCE_INLINE void swap_elements(uint32_t (&array)[N], size_t i, size_t j) { - // Perform the swap - uint32_t temp = array[i]; - array[i] = array[j]; - array[j] = temp; -} - -inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) { - volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast(l1_addr) + start * pagelen; - for (uint32_t page = 0; page < npages; ++page) { - DPRINT << start + page << ": "; - for (uint32_t j = 0; j < pagelen; ++j, ++ptr) { - DPRINT << BF16(*ptr) << " "; - } - DPRINT << ENDL(); - } - DPRINT << ENDL(); -} - -FORCE_INLINE void transpose_XW_to_WX(uint32_t input_l1_addr, uint32_t output_l1_addr, uint32_t X, uint32_t W, uint32_t element_size, uint32_t input_page_size, uint32_t output_page_size) { - volatile tt_l1_ptr uint8_t* input_ptr = reinterpret_cast(input_l1_addr); - volatile tt_l1_ptr uint8_t* output_ptr = reinterpret_cast(output_l1_addr); - // transpose from XW, where X is outer and W inner, to WX, where W is outer and X is inner - // each element is element_size bytes - // each row is W elements, and each row is separated by input_page_size bytes - // each output row is X elements, and each row is separated by output_page_size bytes - - for (uint32_t x = 0; x < X; ++x) { - for (uint32_t w = 0; w < W; ++w) { - // Compute the input and output addresses - uint32_t input_addr = x * input_page_size + w * element_size; - uint32_t output_addr = w * output_page_size + x * element_size; - // Copy the element - do we have memcpy? use this for now - for (uint32_t i = 0; i < element_size; ++i) { - output_ptr[output_addr + i] = input_ptr[input_addr + i]; - } - } - } -} - -void kernel_main() { - constexpr bool dst_is_dram = (bool)get_compile_time_arg_val(0); - constexpr uint32_t N = get_compile_time_arg_val(1); - constexpr uint32_t output_page_size = get_compile_time_arg_val(2); - constexpr uint32_t num_rows = get_compile_time_arg_val(3); - - constexpr uint32_t X = get_compile_time_arg_val(4); - constexpr uint32_t X_stride = get_compile_time_arg_val(5); - constexpr uint32_t x_dim = get_compile_time_arg_val(6); - - constexpr uint32_t W = get_compile_time_arg_val(7); - constexpr uint32_t W_stride = get_compile_time_arg_val(8); - constexpr uint32_t input_page_size = get_compile_time_arg_val(9); - constexpr uint32_t element_size_bytes = get_compile_time_arg_val(10); - - constexpr uint32_t w_dim = N - 1; - - const uint32_t dst_addr = get_arg_val(0); - - DPRINT << "N = " << N << ENDL(); - DPRINT << "page_size = " << output_page_size << ENDL(); - DPRINT << "num_rows = " << num_rows << ENDL(); - DPRINT << "x_dim = " << x_dim << ENDL(); - DPRINT << "X = " << X << ENDL(); - DPRINT << "X_stride = " << X_stride << ENDL(); - DPRINT << "x_dim = " << x_dim << ENDL(); - DPRINT << "W = " << W << ENDL(); - DPRINT << "W_stride = " << W_stride << ENDL(); - DPRINT << "input_page_size = " << input_page_size << ENDL(); - DPRINT << "element_size_bytes = " << element_size_bytes << ENDL(); - DPRINT << "w_dim = " << w_dim << ENDL(); - DPRINT << "dst_addr = " << dst_addr << ENDL(); - - const InterleavedAddrGen s0 = { - .bank_base_address = dst_addr, - .page_size = output_page_size - }; - - uint32_t input_shape[N], perm[N], dest_strides[N]; - for (uint32_t i = 1; i <= N; i++) { - input_shape[i - 1] = get_arg_val(i); - perm[i - 1] = get_arg_val(i + N); - dest_strides[i - 1] = get_arg_val(i + 2 * N); - } - - // Adjust for the transpose between X and W dimensions - swap_elements(input_shape, x_dim, w_dim); - for (uint32_t i = 0; i < N; i++) { - if (perm[i] == x_dim) { - perm[i] = w_dim; - } else if (perm[i] == w_dim) { - perm[i] = x_dim; - } - } - - uint32_t x_dim_in_dest = N; // Invalid index - for (uint32_t i = 0; i < N; ++i) { - if (perm[i] == x_dim) { - x_dim_in_dest = i; - break; - } - } - uint32_t transposed_buffer_read_addr = get_read_ptr(tt::CBIndex::c_1); - uint32_t src_multi_idx[N] = {0}; - uint32_t dest_multi_idx[N] = {0}; - for (uint32_t block = 0; block < num_rows / X; ++block) { - // Compute source indices - size_t remaining = block; - for (int32_t d = N - 2; d >= 0; --d) { // Exclude W dimension - if (d == (int32_t)x_dim) { - continue; // Skip x_dim - } - src_multi_idx[d] = remaining % input_shape[d]; - remaining /= input_shape[d]; - } - - // Precompute dest_multi_idx and dest_linear_idx_base - uint32_t dest_linear_idx_base = 0; - for (uint32_t i = 0; i < N; ++i) { - uint32_t src_idx = perm[i]; - if (src_idx != x_dim) { - dest_multi_idx[i] = src_multi_idx[src_idx]; - if (i < N - 1) { // Exclude W dimension - dest_linear_idx_base += dest_multi_idx[i] * dest_strides[i]; - } - } - } - - cb_wait_front(tt::CBIndex::c_0, X); - uint32_t src_buffer_l1_addr = get_read_ptr(tt::CBIndex::c_0); - - // Transpose the block - transpose_XW_to_WX(src_buffer_l1_addr, transposed_buffer_read_addr, X, W, element_size_bytes, input_page_size, output_page_size); - - - - // Update only the changing components inside the loop - for (uint32_t w = 0; w < W; ++w) { - src_multi_idx[x_dim] = w; - dest_multi_idx[x_dim_in_dest] = w; - // for (uint32_t i = 0; i < N; ++i) { - // DPRINT << "dest_multi_idx[" << i << "] = " << dest_multi_idx[i] << " "; - // } - // DPRINT << ENDL(); - // Update dest_linear_idx - uint32_t dest_linear_idx = dest_linear_idx_base; - if (x_dim_in_dest < N - 1) { // Exclude W dimension - dest_linear_idx += dest_multi_idx[x_dim_in_dest] * dest_strides[x_dim_in_dest]; - } - DPRINT << "dest_linear_idx = " << dest_linear_idx << ENDL(); - uint64_t dst_noc_addr = get_noc_addr(dest_linear_idx, s0); - noc_async_write(transposed_buffer_read_addr + w * output_page_size, dst_noc_addr, output_page_size); - DPRINT << ENDL(); - } - noc_async_write_barrier(); - cb_pop_front(tt::CBIndex::c_0, X); - } -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index a4377fcf521..36f9328688d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -50,26 +50,6 @@ struct PermuteDeviceOperation { tensor_return_value_t& tensor_return_value); }; - struct SingleCoreWidthPermute { - // Shared variables are the variables that are shared between the create and override_runtime_arguments methods - struct shared_variables_t { - KernelHandle unary_reader_kernel_id; - KernelHandle unary_writer_kernel_id; - }; - using cached_program_t = ttnn::device_operation::CachedProgram; - - static cached_program_t create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value); - - static void override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value); - }; - struct MultiCoreBlockedGeneric { // Shared variables are the variables that are shared between the create and override_runtime_arguments methods struct shared_variables_t { @@ -92,7 +72,7 @@ struct PermuteDeviceOperation { tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 9e0765e498a..52534178ea2 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -133,133 +133,6 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments( } } -PermuteDeviceOperation::SingleCoreWidthPermute::cached_program_t PermuteDeviceOperation::SingleCoreWidthPermute::create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { - using namespace tt; - using namespace tt::tt_metal; - - const auto& input_tensor = tensor_args.input_tensor; - auto& output_tensor = tensor_return_value; - - auto src_buffer = input_tensor.buffer(); - auto dst_buffer = output_tensor.buffer(); - - tt::tt_metal::Program program{}; - - tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); - uint32_t input_rm_page_size = detail::page_size(input_tensor); - - tt::DataFormat cb_data_format_output = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype()); - uint32_t output_rm_page_size = detail::page_size(tensor_return_value); - - uint32_t num_input_pages = detail::num_pages(input_tensor); - - tt::tt_metal::Device* device = input_tensor.device(); - - uint32_t src0_cb_index = tt::CBIndex::c_0; - uint32_t src1_cb_index = tt::CBIndex::c_1; - uint32_t num_input_pages_to_read = 2; - - // we are focused on reading one row at a time, in a pattern that allows us to write an entire output row at a time - // if W is being swapped with another dim X (e.g. H), then we need to read X rows at a time (X is the new row dimension) - // CB is thus X pages in size (X*W*element_size) - // we read in X input rows of size W, and write out W output rows of size X - // find the new row dimension (X) - - uint32_t x_dim = operation_attributes.dims.back(); - uint32_t X = input_tensor.get_logical_shape()[x_dim]; - // stride from one row to the next for each dim in the input tensor - auto input_strides = detail::get_row_strides(input_tensor.get_logical_shape()); - uint32_t X_stride = input_strides[x_dim]; - - auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); - // after we transpose X and W, we need to stride from one row to the next for each dim in the output tensor - uint32_t W = input_tensor.get_logical_shape()[-1]; - uint32_t W_stride = output_strides[x_dim]; - - CoreRange core({0, 0}, {0, 0}); - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig( - num_input_pages_to_read * input_rm_page_size * X, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, input_rm_page_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); - - tt::tt_metal::CircularBufferConfig cb_src1_config = - tt::tt_metal::CircularBufferConfig( - num_input_pages_to_read * output_rm_page_size * W, {{src1_cb_index, cb_data_format}}) - .set_page_size(src1_cb_index, output_rm_page_size); - auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src1_config); - - uint32_t N = operation_attributes.dims.size(); - uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; - - bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows, x_dim}; - - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_width_permute.cpp", - core, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); - - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows, X, X_stride, x_dim, W, W_stride, input_rm_page_size, input_tensor.element_size()}; - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_width_permute.cpp", - core, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - auto input_shape_view = input_tensor.get_logical_shape().view(); - - std::vector reader_runtime_args = {src_buffer->address()}; - reader_runtime_args.insert(reader_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); - reader_runtime_args.insert(reader_runtime_args.end(), input_strides.begin(), input_strides.end()); - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); - - - std::vector writer_runtime_args = {dst_buffer->address()}; - writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); - writer_runtime_args.insert( - writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); - writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); - - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); - - return { - std::move(program), - {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}}; -} - -void PermuteDeviceOperation::SingleCoreWidthPermute::override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& tensor_return_value) { - auto& program = cached_program.program; - auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; - auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; - - const auto& input_tensor = tensor_args.input_tensor; - auto& output_tensor = tensor_return_value; - - auto src_buffer = input_tensor.buffer(); - auto dst_buffer = output_tensor.buffer(); - - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); - runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); - runtime_args[0] = dst_buffer->address(); - } -} - PermuteDeviceOperation::MultiCoreBlockedGeneric::cached_program_t PermuteDeviceOperation::MultiCoreBlockedGeneric::create( const operation_attributes_t& operation_attributes, diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index a33e426d5e3..98768310aa2 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -190,6 +190,8 @@ ttnn::Tensor ExecutePermute::invoke( return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config())); } + auto padded_shape = input_tensor.get_padded_shape(); + const auto input_layout = input_tensor.get_layout(); auto adjust_order = [](tt::stl::Span dims) { ttnn::SmallVector new_order; @@ -204,28 +206,14 @@ ttnn::Tensor ExecutePermute::invoke( return new_order; }; auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; - auto iorder = normalized_dims.size() < 4 - ? adjust_order(normalized_dims) - : normalized_dims; // internals of permute_impl already adjust negative indices + auto iorder = normalized_dims.size() < 4 ? adjust_order(normalized_dims) : normalized_dims; auto output_tensor = detail::permute_launch(itensor, iorder, memory_config.value_or(input_tensor.memory_config()), pad_value); output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr); if (input_rank < 4) { - const auto shape = output_tensor.get_shape(); - const auto full_shape = output_tensor.get_shape().with_tile_padding(); - SmallVector shape_vec{}; - SmallVector full_shape_vec{}; - int i = 0; - while (i < 3 and shape[i] == 1) { - i++; - } - for (; i < shape.rank(); i++) { - shape_vec.push_back(shape[i]); - full_shape_vec.push_back(full_shape[i]); - } - output_tensor = ttnn::reshape(output_tensor, ttnn::Shape(shape_vec, full_shape_vec)); + output_tensor = ttnn::squeeze_from_4D(output_tensor, input_rank); } return output_tensor; From 1cf08a9bd24d68f54424ad7e7c2a01996d86744f Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 02:51:30 +0000 Subject: [PATCH 10/20] #0: make transpose use prim permute in RM workaround cases --- .../unit_testing/misc/test_transpose.py | 6 +-- .../device/permute_program_factory.cpp | 2 - .../data_movement/transpose/transpose.cpp | 50 +++++-------------- 3 files changed, 15 insertions(+), 43 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index ef89d551683..3b85e37b15e 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -369,9 +369,7 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache): memory_config=ttnn.L1_MEMORY_CONFIG, ) activation_pyt_padded = ttnn.transpose(activation_pyt_padded, 2, 3, memory_config=ttnn.L1_MEMORY_CONFIG) - activation_pyt_padded_out = ttnn.to_memory_config(activation_pyt_padded, ttnn.L1_MEMORY_CONFIG) - activation_pyt_padded_out = ttnn.from_device(activation_pyt_padded_out) - activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded_out) + activation_pyt_padded_out = ttnn.to_torch(activation_pyt_padded) assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999) @@ -384,7 +382,7 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache): def test_tranpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cache): for _ in range(2): run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache) - # dummy tensor to change tensor alloc + # # dummy tensor to change tensor alloc dummy_shape = [1, 1, 32, 32] py_dummy_tensor = torch.randn(dummy_shape) tt_dummy_tensor = ttnn.from_torch( diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 52534178ea2..0d726d45be0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -344,8 +344,6 @@ void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments runtime_args[0] = src_buffer->address(); auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); runtime_args_writer[0] = dst_buffer->address(); - auto& runtime_args_compute = tt::tt_metal::GetRuntimeArgs(program, compute_kernel_id, core); - runtime_args_compute[0] = dst_buffer->address(); } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp index c7a20848999..1db242d5e7d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp @@ -11,6 +11,7 @@ #include "ttnn/cpp/ttnn/operations/copy.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" // FIXME: ARCH_NAME specific include #include "noc/noc_parameters.h" // DRAM_ALIGNMENT @@ -56,20 +57,14 @@ inline Tensor transpose_( TransposeOpDim transpose_dim, const MemoryConfig& output_mem_config, const std::optional& pad_value) { - bool tiled_only = false; - constexpr uint32_t FACE_WIDTH = - tt::constants::FACE_WIDTH; // this is a highly restrictive constraint on the RM transpose_wh kernel, and with - // all the other bugs/limitations we should rewrite it - // use device->get_allocator_alignment when the it reflects the alignment of the buffer and doesn't just default to - // DRAM auto BUFFER_ALIGNMENT = a.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT; uint32_t W = a.get_padded_shape()[-1]; uint32_t H = a.get_padded_shape()[-2]; switch (transpose_dim) { case TransposeOpDim::HC: - tiled_only = a.get_layout() == Layout::TILE; - if ((!tiled_only) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { // - tiled_only = true; + if ((a.get_layout() == Layout::ROW_MAJOR) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { // + return ttnn::prim::permute( + (const ttnn::Tensor)a, ttnn::SmallVector({0, 2, 1, 3}), output_mem_config, std::nullopt); } break; // bubble dim around to make it possible as these implementations don't have a kernel @@ -83,39 +78,20 @@ inline Tensor transpose_( return ttnn::permute( (const ttnn::Tensor)a, ttnn::SmallVector({0, 3, 2, 1}), output_mem_config, pad_value); case TransposeOpDim::CN: - tiled_only = true; // CN only has a tiled implementation at the moment + if (a.get_layout() == Layout::ROW_MAJOR) { + return ttnn::prim::permute( + (const ttnn::Tensor)a, ttnn::SmallVector({1, 0, 2, 3}), output_mem_config, std::nullopt); + } break; - case TransposeOpDim::WH: // THIS NEEDS TO BE FIXED - if (((W * a.element_size()) % FACE_WIDTH != 0) || ((H * a.element_size()) % FACE_WIDTH != 0)) { - tiled_only = true; - } else if (a.device()->arch() == tt::ARCH::GRAYSKULL) { - tiled_only = a.shape()[-2] > 256; // hangs right now past this dimension, #13660 will turn it from a - // hang into a PCC issue for GS and improve perf for WH - } else if ( - !a.is_sharded() && a.layout() == Layout::ROW_MAJOR && - !rm_enough_available_space( - a)) { // rm is L1 intensive, if it overflows we can do tiled which allocates much smaller CBs - tiled_only = true; + case TransposeOpDim::WH: + if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) { + return ttnn::prim::permute( + (const ttnn::Tensor)a, ttnn::SmallVector({0, 1, 3, 2}), output_mem_config, std::nullopt); } break; default: break; } - if (a.get_layout() == Layout::ROW_MAJOR) { - // the assorted cases where only tiled works right now (HC with stick width constraint, WH with stick width - // constraint, CN). - if (tiled_only) { - // convert to tiled - Tensor b = ttnn::to_layout(a, Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr); - // run the transpose. - b = operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {b}).at(0); - // back to original layout - b = ttnn::to_layout(b, a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr); - return b; - } - return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); - } else { - return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); - } + return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); } ttnn::Tensor transpose_nd( From bac01ce3a8c891978f31ae0b017d8c6b862da9f8 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 05:36:27 +0000 Subject: [PATCH 11/20] #0: make row-invariant permute kernel multicore --- ..._permute_interleaved_rm_row_invariant.cpp} | 6 +- ..._permute_interleaved_rm_row_invariant.cpp} | 12 ++-- .../device/permute_device_operation.cpp | 6 +- .../device/permute_device_operation.hpp | 5 +- .../device/permute_program_factory.cpp | 65 +++++++++++++------ 5 files changed, 62 insertions(+), 32 deletions(-) rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{reader_permute_interleaved_rm.cpp => reader_permute_interleaved_rm_row_invariant.cpp} (78%) rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/{writer_permute_interleaved_rm.cpp => writer_permute_interleaved_rm_row_invariant.cpp} (83%) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp similarity index 78% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp index b5ffc12cf7d..93a42f81325 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_row_invariant.cpp @@ -12,14 +12,16 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t src_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = page_size}; uint32_t curr_addr = src_addr; - for (uint32_t i = 0; i < num_rows; ++i) { + for (uint32_t row = start_row; row < end_row; ++row) { cb_reserve_back(tt::CBIndex::c_0, 1); uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); - noc_async_read_page(i, s0, src_buffer_l1_addr); + noc_async_read_page(row, s0, src_buffer_l1_addr); noc_async_read_barrier(); cb_push_back(tt::CBIndex::c_0, 1); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp similarity index 83% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp index 34be75dfdf4..46903375ff6 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp @@ -12,19 +12,21 @@ void kernel_main() { constexpr uint32_t num_rows = get_compile_time_arg_val(3); const uint32_t dst_addr = get_arg_val(0); + const uint32_t start_row = get_arg_val(1); + const uint32_t end_row = get_arg_val(2); const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = page_size}; uint32_t input_shape[N], perm[N], dest_strides[N]; - for (uint32_t i = 1; i <= N; i++) { - input_shape[i - 1] = get_arg_val(i); - perm[i - 1] = get_arg_val(i + N); - dest_strides[i - 1] = get_arg_val(i + 2 * N); + for (uint32_t i = 3; i < N + 3; i++) { + input_shape[i - 3] = get_arg_val(i); + perm[i - 3] = get_arg_val(i + N); + dest_strides[i - 3] = get_arg_val(i + 2 * N); } uint32_t src_buffer_l1_addr = get_write_ptr(tt::CBIndex::c_0); uint32_t curr_addr = dst_addr; - for (uint32_t row = 0; row < num_rows; ++row) { + for (uint32_t row = start_row; row < end_row; ++row) { // Compute multi-dimensional index for the source row uint32_t src_multi_idx[N]; size_t remaining = row; diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp index af6ee177fca..8bc4bece3b0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp @@ -12,9 +12,11 @@ namespace ttnn::operations::data_movement { PermuteDeviceOperation::program_factory_t PermuteDeviceOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + // If the last dimension is not permuted, we can use the row-invariant kernel if (operation_attributes.dims.back() == tensor_args.input_tensor.get_logical_shape().rank() - 1) { - return SingleCore{}; + return MultiCoreRowInvariant{}; } + // Otherwise, we need to use the blocked generic, row moving kernel return MultiCoreBlockedGeneric{}; } @@ -33,7 +35,7 @@ void PermuteDeviceOperation::validate_on_program_cache_hit( PermuteDeviceOperation::shape_return_value_t PermuteDeviceOperation::compute_output_shapes( const operation_attributes_t& attributes, const tensor_args_t& tensor_args) { - SmallVector shape, padded_shape; + SmallVector shape; auto input_shape = tensor_args.input_tensor.get_logical_shape(); shape.reserve(input_shape.rank()); for (auto dim : attributes.dims) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index 36f9328688d..e27b8251bdc 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -30,11 +30,12 @@ struct PermuteDeviceOperation { using tensor_return_value_t = Tensor; - struct SingleCore { + struct MultiCoreRowInvariant { // Shared variables are the variables that are shared between the create and override_runtime_arguments methods struct shared_variables_t { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; + CoreRangeSet all_cores; }; using cached_program_t = ttnn::device_operation::CachedProgram; @@ -72,7 +73,7 @@ struct PermuteDeviceOperation { tensor_return_value_t& tensor_return_value); }; - using program_factory_t = std::variant; + using program_factory_t = std::variant; // Mandatory methods diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 0d726d45be0..5ead48cb7cb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -32,7 +32,7 @@ std::vector get_row_strides(const ttnn::SimpleShape& shape) { } // namespace detail -PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::SingleCore::create( +PermuteDeviceOperation::MultiCoreRowInvariant::cached_program_t PermuteDeviceOperation::MultiCoreRowInvariant::create( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& tensor_return_value) { @@ -60,54 +60,78 @@ PermuteDeviceOperation::SingleCore::cached_program_t PermuteDeviceOperation::Sin uint32_t src0_cb_index = tt::CBIndex::c_0; uint32_t num_input_pages_to_read = 2; - CoreRange core({0, 0}, {0, 0}); + uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; + + auto compute_with_storage_grid_size = input_tensor.device()->compute_with_storage_grid_size(); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_rows); + tt::tt_metal::CircularBufferConfig cb_src0_config = tt::tt_metal::CircularBufferConfig( num_input_pages_to_read * input_rm_page_size, {{src0_cb_index, cb_data_format}}) .set_page_size(src0_cb_index, input_rm_page_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); uint32_t N = operation_attributes.dims.size(); - uint32_t num_rows = input_tensor.volume() / input_tensor.get_logical_shape()[-1]; bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = {(uint32_t)src_is_dram, N, input_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "reader_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args)); bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; std::vector writer_compile_time_args = {(std::uint32_t)dst_is_dram, N, output_rm_page_size, num_rows}; tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm.cpp", - core, + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/" + "writer_permute_interleaved_rm_row_invariant.cpp", + all_cores, tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - std::vector reader_runtime_args = {src_buffer->address()}; - - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + std::vector reader_runtime_args = {src_buffer->address(), 0, 0}; auto input_shape_view = input_tensor.get_logical_shape().view(); auto output_strides = detail::get_row_strides(output_tensor.get_logical_shape()); // in anticipation of RM padding - std::vector writer_runtime_args = {dst_buffer->address()}; + std::vector writer_runtime_args = {dst_buffer->address(), 0, 0}; writer_runtime_args.insert(writer_runtime_args.end(), input_shape_view.begin(), input_shape_view.end()); writer_runtime_args.insert( writer_runtime_args.end(), operation_attributes.dims.begin(), operation_attributes.dims.end()); writer_runtime_args.insert(writer_runtime_args.end(), output_strides.begin(), output_strides.end()); - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + auto cores = corerange_to_cores(all_cores, std::nullopt); + uint32_t start_row = 0; + uint32_t num_rows_per_core = 0; + for (const auto& core : cores) { + if (core_group_1.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_1; + } else if (core_group_2.contains(core)) { + num_rows_per_core = num_tiles_per_core_group_2; + } else { + // no-op + num_rows_per_core = 0; + } + uint32_t end_row = start_row + num_rows_per_core; + reader_runtime_args[1] = start_row; + reader_runtime_args[2] = end_row; + writer_runtime_args[1] = start_row; + writer_runtime_args[2] = end_row; + tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_runtime_args); + tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_runtime_args); + start_row = end_row; + } return { std::move(program), {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id}}; } -void PermuteDeviceOperation::SingleCore::override_runtime_arguments( +void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments( cached_program_t& cached_program, const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, @@ -121,15 +145,14 @@ void PermuteDeviceOperation::SingleCore::override_runtime_arguments( auto src_buffer = input_tensor.buffer(); auto dst_buffer = output_tensor.buffer(); + auto& all_cores = cached_program.shared_variables.all_cores; - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); + auto cores = corerange_to_cores(all_cores, std::nullopt); + for (const auto& core : cores) { + auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, core); runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); - runtime_args[0] = dst_buffer->address(); + auto& runtime_args_writer = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args_writer[0] = dst_buffer->address(); } } From 12858fb3b6a41f98c315f5883c70a28898e27a91 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 17:54:48 +0000 Subject: [PATCH 12/20] #0: revert transpose changes for now --- .../data_movement/transpose/transpose.cpp | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp index 1db242d5e7d..c7a20848999 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/transpose.cpp @@ -11,7 +11,6 @@ #include "ttnn/cpp/ttnn/operations/copy.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/slice/slice.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp" // FIXME: ARCH_NAME specific include #include "noc/noc_parameters.h" // DRAM_ALIGNMENT @@ -57,14 +56,20 @@ inline Tensor transpose_( TransposeOpDim transpose_dim, const MemoryConfig& output_mem_config, const std::optional& pad_value) { + bool tiled_only = false; + constexpr uint32_t FACE_WIDTH = + tt::constants::FACE_WIDTH; // this is a highly restrictive constraint on the RM transpose_wh kernel, and with + // all the other bugs/limitations we should rewrite it + // use device->get_allocator_alignment when the it reflects the alignment of the buffer and doesn't just default to + // DRAM auto BUFFER_ALIGNMENT = a.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT; uint32_t W = a.get_padded_shape()[-1]; uint32_t H = a.get_padded_shape()[-2]; switch (transpose_dim) { case TransposeOpDim::HC: - if ((a.get_layout() == Layout::ROW_MAJOR) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { // - return ttnn::prim::permute( - (const ttnn::Tensor)a, ttnn::SmallVector({0, 2, 1, 3}), output_mem_config, std::nullopt); + tiled_only = a.get_layout() == Layout::TILE; + if ((!tiled_only) && ((W * a.element_size()) % BUFFER_ALIGNMENT != 0)) { // + tiled_only = true; } break; // bubble dim around to make it possible as these implementations don't have a kernel @@ -78,20 +83,39 @@ inline Tensor transpose_( return ttnn::permute( (const ttnn::Tensor)a, ttnn::SmallVector({0, 3, 2, 1}), output_mem_config, pad_value); case TransposeOpDim::CN: - if (a.get_layout() == Layout::ROW_MAJOR) { - return ttnn::prim::permute( - (const ttnn::Tensor)a, ttnn::SmallVector({1, 0, 2, 3}), output_mem_config, std::nullopt); - } + tiled_only = true; // CN only has a tiled implementation at the moment break; - case TransposeOpDim::WH: - if (!a.is_sharded() && a.layout() == Layout::ROW_MAJOR) { - return ttnn::prim::permute( - (const ttnn::Tensor)a, ttnn::SmallVector({0, 1, 3, 2}), output_mem_config, std::nullopt); + case TransposeOpDim::WH: // THIS NEEDS TO BE FIXED + if (((W * a.element_size()) % FACE_WIDTH != 0) || ((H * a.element_size()) % FACE_WIDTH != 0)) { + tiled_only = true; + } else if (a.device()->arch() == tt::ARCH::GRAYSKULL) { + tiled_only = a.shape()[-2] > 256; // hangs right now past this dimension, #13660 will turn it from a + // hang into a PCC issue for GS and improve perf for WH + } else if ( + !a.is_sharded() && a.layout() == Layout::ROW_MAJOR && + !rm_enough_available_space( + a)) { // rm is L1 intensive, if it overflows we can do tiled which allocates much smaller CBs + tiled_only = true; } break; default: break; } - return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); + if (a.get_layout() == Layout::ROW_MAJOR) { + // the assorted cases where only tiled works right now (HC with stick width constraint, WH with stick width + // constraint, CN). + if (tiled_only) { + // convert to tiled + Tensor b = ttnn::to_layout(a, Layout::TILE, std::nullopt, std::nullopt, (Device*)nullptr); + // run the transpose. + b = operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {b}).at(0); + // back to original layout + b = ttnn::to_layout(b, a.get_layout(), std::nullopt, std::nullopt, (Device*)nullptr); + return b; + } + return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); + } else { + return operation::run(Transpose{transpose_dim, output_mem_config, pad_value}, {a}).at(0); + } } ttnn::Tensor transpose_nd( From 1baf18ed44b0aa4df3ad1e17b3f437cea31695da Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 19:55:58 +0000 Subject: [PATCH 13/20] #0: disable tests on GS due to bad pcc --- tests/ttnn/unit_tests/operations/test_permute.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 4d9c8197b5a..158ad3515c4 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -10,7 +10,7 @@ import itertools from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import is_blackhole +from models.utility_functions import is_blackhole, is_grayskull, skip_for_grayskull @pytest.mark.parametrize("h", [32]) @@ -185,6 +185,7 @@ def generate_permutations(N): yield perm +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") @pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)]) @pytest.mark.parametrize("perm", generate_permutations(5)) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) @@ -203,6 +204,7 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): assert_with_pcc(torch_output, tt_output, 0.9999) +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") @pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)]) @pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) @@ -223,6 +225,7 @@ def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): assert_with_pcc(torch_output, tt_output, 0.9999) +@skip_for_grayskull("tilize_block gives bad pcc after second iteration") def test_permute_nd(device): torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) @@ -245,6 +248,8 @@ def test_permute_squeeze(device): @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) def test_permute_3D(shape, perm, layout, memory_config, dtype, device): + if is_grayskull() and dtype == ttnn.float32: + pytest.skip("Grayskull doesn't support float32") torch_tensor = torch.rand(shape, dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=dtype, memory_config=memory_config) output_tensor = ttnn.permute(input_tensor, perm) From 04c8e848fd49a09a875396baad5857cb3e33078b Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Tue, 10 Dec 2024 20:50:41 +0000 Subject: [PATCH 14/20] #12349: add back tests that now work on blackhole #12550: re-enable some permute tests, disable the ones that aren't working --- .../sweep_tests/pytests/tt_dnn/test_permute.py | 1 - .../python_api_testing/unit_testing/misc/test_transpose.py | 7 +------ tests/ttnn/unit_tests/operations/test_permute.py | 5 ++++- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py index d9ab2571a58..98699e2d4f2 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_permute.py @@ -20,7 +20,6 @@ ] -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("input_shapes, permute_args", params) def test_run_permute_test(input_shapes, permute_args, device, function_level_defaults): datagen_func = [ diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 3b85e37b15e..4b51e48e761 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -299,7 +299,6 @@ def test_transpose_wh_sharded_program_cache(dtype, device, use_program_cache): ) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [1]) @pytest.mark.parametrize("c", [1]) @@ -333,7 +332,6 @@ def test_tranpose_hw_rm_with_padding(device, n, c, h, w): assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @@ -373,7 +371,6 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache): assert_with_pcc(torch_output_tensor, activation_pyt_padded_out, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @skip_for_grayskull("Grayskull has pcc issue when transpose used untilize") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @@ -400,7 +397,7 @@ def test_tranpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cache @pytest.mark.parametrize("c", [224]) @pytest.mark.parametrize("h", [16]) @pytest.mark.parametrize("w", [112]) -def test_tranpose_hw_sharded_rm(device, n, c, h, w): +def test_transpose_hw_sharded_rm(device, n, c, h, w): torch.manual_seed(2005) torch_input_tensor = torch.rand((n, c, h, w), dtype=torch.bfloat16) torch_output_tensor = torch_input_tensor.transpose(2, 3) @@ -467,7 +464,6 @@ def run_tranpose_hw_sharded_rm_with_program_cache(device, n, c, h, w): assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize("n", [16]) @pytest.mark.parametrize("c", [128]) @pytest.mark.parametrize("h", [128]) @@ -579,7 +575,6 @@ def run_tranpose_hc_sharded(device, n, c, h, w, grid_size): assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) -@skip_for_blackhole("Mismatching on BH, see #12349") @pytest.mark.parametrize( "n, c, h, w, grid_size", [ diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 158ad3515c4..2d919cd2af6 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -10,7 +10,7 @@ import itertools from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import is_blackhole, is_grayskull, skip_for_grayskull +from models.utility_functions import is_blackhole, is_grayskull, skip_for_grayskull, skip_for_blackhole @pytest.mark.parametrize("h", [32]) @@ -185,6 +185,7 @@ def generate_permutations(N): yield perm +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") @skip_for_grayskull("tilize_block gives bad pcc after second iteration") @pytest.mark.parametrize("shape", [(7, 7, 7, 7, 7)]) @pytest.mark.parametrize("perm", generate_permutations(5)) @@ -204,6 +205,7 @@ def test_permute_5d_width(shape, perm, memory_config, dtype, device): assert_with_pcc(torch_output, tt_output, 0.9999) +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") @skip_for_grayskull("tilize_block gives bad pcc after second iteration") @pytest.mark.parametrize("shape", [(3, 65, 3, 3, 65), (1, 6, 256, 20, 50), (6, 20, 50, 1, 256)]) @pytest.mark.parametrize("perm", [(4, 0, 3, 2, 1), (1, 3, 4, 0, 2), (3, 0, 4, 1, 2)]) @@ -225,6 +227,7 @@ def test_permute_5d_blocked(shape, perm, memory_config, dtype, device): assert_with_pcc(torch_output, tt_output, 0.9999) +@skip_for_blackhole("tilize_block gives bad pcc after second iteration") @skip_for_grayskull("tilize_block gives bad pcc after second iteration") def test_permute_nd(device): torch_tensor = torch.rand((1, 3, 16, 16, 16, 16), dtype=torch.bfloat16) From 5e23f4bad80d35aa326db0026dc8d5fbd0dcf2f5 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Wed, 11 Dec 2024 21:50:09 +0000 Subject: [PATCH 15/20] #0: rename transpose compute kernel to xw from xh --- .../python_api_testing/unit_testing/misc/test_transpose.py | 2 +- tests/ttnn/unit_tests/operations/test_permute.py | 2 +- ...ingle_tile_size.cpp => transpose_xw_rm_single_tile_size.cpp} | 0 .../data_movement/permute/device/permute_program_factory.cpp | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/{transpose_xh_rm_single_tile_size.cpp => transpose_xw_rm_single_tile_size.cpp} (100%) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 4b51e48e761..6817fc1bd82 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -379,7 +379,7 @@ def run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache): def test_tranpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cache): for _ in range(2): run_tranpose_hw_rm_program_cache(device, n, c, h, w, use_program_cache) - # # dummy tensor to change tensor alloc + # dummy tensor to change tensor alloc dummy_shape = [1, 1, 32, 32] py_dummy_tensor = torch.randn(dummy_shape) tt_dummy_tensor = ttnn.from_torch( diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 2d919cd2af6..70e0fb4394d 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -247,7 +247,7 @@ def test_permute_squeeze(device): @pytest.mark.parametrize("shape", [(1, 49, 768)]) @pytest.mark.parametrize("perm", generate_permutations(3)) -@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.float32]) def test_permute_3D(shape, perm, layout, memory_config, dtype, device): diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp similarity index 100% rename from ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp rename to ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 5ead48cb7cb..1a7a735d7e5 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -290,7 +290,7 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( bool fp32_dest_acc_en = cb_data_format_output == tt::DataFormat::Float32; auto compute_kernel_id = tt::tt_metal::CreateKernel( program, - "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xh_rm_single_tile_size.cpp", + "ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp", all_cores, tt::tt_metal::ComputeConfig{ .fp32_dest_acc_en = fp32_dest_acc_en, From 8ac1b8dd9d663052b5dfe0ccf211718fa9c46d2e Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Thu, 12 Dec 2024 20:48:14 +0000 Subject: [PATCH 16/20] #0: fix uninit issue --- .../unit_testing/misc/test_transpose.py | 2 - .../unit_tests/operations/test_permute.py | 10 +++ .../transpose_xw_rm_single_tile_size.cpp | 7 +-- .../device/permute_device_operation.hpp | 4 +- .../device/permute_program_factory.cpp | 14 ++--- .../data_movement/permute/permute.cpp | 4 +- .../transpose/device/transpose_op.cpp | 62 ++++++++++++------- 7 files changed, 59 insertions(+), 44 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py index 6817fc1bd82..a8f8385c059 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_transpose.py @@ -1037,8 +1037,6 @@ def test_tranpose_hw_sharded_tiled_8_cores(device, n, c, h, w): tt_input_tensor = ttnn.to_memory_config(tt_input_tensor, sharded_mem_config) tt_output_tensor = ttnn.transpose(tt_input_tensor, 2, 3, memory_config=sharded_mem_config) - tt_output_tensor = ttnn.to_memory_config(tt_output_tensor, ttnn.L1_MEMORY_CONFIG) - tt_output_tensor = ttnn.from_device(tt_output_tensor) tt_output_tensor = ttnn.to_torch(tt_output_tensor) assert_with_pcc(torch_output_tensor, tt_output_tensor, 0.9999) diff --git a/tests/ttnn/unit_tests/operations/test_permute.py b/tests/ttnn/unit_tests/operations/test_permute.py index 70e0fb4394d..cc09f7d7d5e 100644 --- a/tests/ttnn/unit_tests/operations/test_permute.py +++ b/tests/ttnn/unit_tests/operations/test_permute.py @@ -260,3 +260,13 @@ def test_permute_3D(shape, perm, layout, memory_config, dtype, device): torch_output = torch.permute(torch_tensor, perm) assert torch_output.shape == output_tensor.shape assert_with_pcc(torch_output, output_tensor, 0.9999) + + +def test_nil_volume_permute(device): + torch_tensor = torch.rand([1, 0, 30, 32], dtype=torch.bfloat16) + input_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.permute(input_tensor, (0, 1, 3, 2)) + output_tensor = ttnn.to_torch(output_tensor) + torch_output = torch.permute(torch_tensor, (0, 1, 3, 2)) + assert torch_output.shape == output_tensor.shape + assert_with_pcc(torch_output, output_tensor, 0.9999) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp index 0b69361f57a..41151070064 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/compute/transpose_xw_rm_single_tile_size.cpp @@ -24,17 +24,12 @@ void MAIN { unary_op_init_common(cb_in, cb_out); for (uint32_t n = 0; n < num_blocks; n++) { - // have to global init here, otherwise pcc is bad - // if n > 0, then some register isn't cleared and the output of tilize_block is garbage - unary_op_init_common(cb_in, cb_out); // tilize input via unpack and then pack tilize_init_short(cb_in, 1); cb_wait_front(cb_in, x_block_size); - // results are correct according to unpacker here cb_reserve_back(cb_tilize, 1); - // removing this line causes the output of tilize_block to be garbage in the second iteration tilize_block(cb_in, 1, cb_tilize); // tilize and pack into cb_tilize // tile slice according to unpacker is garbage after tilize_block in the second iteration, missing an uninit? @@ -62,7 +57,7 @@ void MAIN { cb_push_back(cb_out, w_block_size); cb_wait_front(cb_out, w_block_size); - pack_untilize_uninit(); + pack_untilize_uninit(cb_out); cb_pop_front(cb_tilize, 1); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp index e27b8251bdc..05e251e8ca8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.hpp @@ -35,7 +35,7 @@ struct PermuteDeviceOperation { struct shared_variables_t { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; - CoreRangeSet all_cores; + CoreRangeSet core_range; }; using cached_program_t = ttnn::device_operation::CachedProgram; @@ -57,7 +57,7 @@ struct PermuteDeviceOperation { KernelHandle unary_reader_kernel_id; KernelHandle unary_writer_kernel_id; KernelHandle compute_kernel_id; - CoreRangeSet all_cores; + CoreRangeSet core_range; }; using cached_program_t = ttnn::device_operation::CachedProgram; diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp index 1a7a735d7e5..56bfe893f5d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/permute_program_factory.cpp @@ -10,14 +10,14 @@ namespace ttnn::operations::data_movement { namespace detail { uint32_t num_pages(const ttnn::Tensor& input_tensor) { - const auto& padded_shape = input_tensor.get_logical_shape(); - return padded_shape.volume() / padded_shape[-1]; + const auto& shape = input_tensor.get_logical_shape(); + return shape.volume() / shape[-1]; } uint32_t page_size(const ttnn::Tensor& input_tensor) { auto BUFFER_ALIGNMENT = input_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? DRAM_ALIGNMENT : L1_ALIGNMENT; - const auto& padded_shape = input_tensor.get_logical_shape(); // in anticipation of RM padding - return tt::round_up(padded_shape[-1] * input_tensor.element_size(), BUFFER_ALIGNMENT); + const auto& shape = input_tensor.get_logical_shape(); // in anticipation of RM padding + return tt::round_up(shape[-1] * input_tensor.element_size(), BUFFER_ALIGNMENT); } std::vector get_row_strides(const ttnn::SimpleShape& shape) { @@ -145,7 +145,7 @@ void PermuteDeviceOperation::MultiCoreRowInvariant::override_runtime_arguments( auto src_buffer = input_tensor.buffer(); auto dst_buffer = output_tensor.buffer(); - auto& all_cores = cached_program.shared_variables.all_cores; + auto& all_cores = cached_program.shared_variables.core_range; auto cores = corerange_to_cores(all_cores, std::nullopt); for (const auto& core : cores) { @@ -341,7 +341,7 @@ PermuteDeviceOperation::MultiCoreBlockedGeneric::create( {.unary_reader_kernel_id = unary_reader_kernel_id, .unary_writer_kernel_id = unary_writer_kernel_id, .compute_kernel_id = compute_kernel_id, - .all_cores = all_cores}}; + .core_range = all_cores}}; } void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments( @@ -359,7 +359,7 @@ void PermuteDeviceOperation::MultiCoreBlockedGeneric::override_runtime_arguments auto src_buffer = input_tensor.buffer(); auto dst_buffer = output_tensor.buffer(); - auto& all_cores = cached_program.shared_variables.all_cores; + auto& all_cores = cached_program.shared_variables.core_range; auto cores = corerange_to_cores(all_cores, std::nullopt); for (const auto& core : cores) { diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index 98768310aa2..f0f33a99040 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -159,7 +159,7 @@ ttnn::Tensor permute_launch( } bool is_permute_nop(const ttnn::Tensor& a, tt::stl::Span dims) { - if (a.get_shape().rank() == 1) { + if (a.get_shape().rank() <= 1) { return true; } auto normalized_dims = ttnn::SmallVector(dims.begin(), dims.end()); @@ -190,8 +190,6 @@ ttnn::Tensor ExecutePermute::invoke( return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config())); } - auto padded_shape = input_tensor.get_padded_shape(); - const auto input_layout = input_tensor.get_layout(); auto adjust_order = [](tt::stl::Span dims) { ttnn::SmallVector new_order; diff --git a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp index 05bc7b3f6cf..c560cd17ebe 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_op.cpp @@ -23,14 +23,29 @@ void Transpose::validate(const std::vector& input_tensors) const { TT_FATAL(input_tensor.buffer() != nullptr, "Operands to transpose need to be allocated in buffers on device!"); TT_FATAL( !(this->dim != TransposeOpDim::HC && this->pad_value.has_value() && this->pad_value != 0.0f), - "Non-zero padding is not supported for any transpose other than HC."); + "Non-zero padding {} is not supported for any transpose other than HC.", + this->pad_value.value()); + TT_FATAL( + this->dim == TransposeOpDim::HC || this->dim == TransposeOpDim::WH || this->dim == TransposeOpDim::CN, + "Transpose HC, WH, CN are the only supported transpose operations. Transpose {} is not supported.", + (int)this->dim); const auto shape = input_tensor.get_padded_shape(); bool row_major = input_tensor.get_layout() == Layout::ROW_MAJOR; uint32_t W = shape[3], H = shape[2], C = shape[1], N = shape[0]; uint32_t HW = H * W; if (not row_major) { - TT_FATAL(W % TILE_WIDTH == 0 && H % TILE_HEIGHT == 0, "Error"); - TT_FATAL(input_tensor.volume() % TILE_HW == 0, "Error"); + TT_FATAL( + W % TILE_WIDTH == 0 && H % TILE_HEIGHT == 0, + "Tiled tensor H {} W {} must be a multiple of TILE HEIGHT {} and TILE WIDTH", + H, + W, + TILE_HEIGHT, + TILE_WIDTH); + TT_FATAL( + input_tensor.volume() % TILE_HW == 0, + "Tiled tensor volume {} must be a multiple of TILE HEIGHT * TILE WIDTH", + input_tensor.volume(), + TILE_HW); } uint32_t ROW_MAJOR_STICK_WIDTH = 16; if (this->dim == TransposeOpDim::WH) { @@ -38,19 +53,28 @@ void Transpose::validate(const std::vector& input_tensors) const { TT_FATAL( (W * input_tensor.element_size()) % ROW_MAJOR_STICK_WIDTH == 0 && (H * input_tensor.element_size()) % ROW_MAJOR_STICK_WIDTH == 0, - "Error"); + "Row major tensor W {} H {} must be a multiple of ROW_MAJOR_STICK_WIDTH for transpose wh", + W, + H, + ROW_MAJOR_STICK_WIDTH); } if (input_tensor.is_sharded()) { - TT_FATAL(input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); + TT_FATAL( + input_tensor.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, + "Only height and block sharding is supported for transpose wh"); const auto shard_spec = input_tensor.shard_spec().value(); TT_FATAL( (shard_spec.shape[0] % H == 0) || (H % shard_spec.shape[0] == 0), - "Only a multiple of H or a factor of H is allows for the shard height"); + "Only a multiple of H {} or a factor of H is allows for the shard height {} for transpose WH", + H, + shard_spec.shape[0]); TT_FATAL(shard_spec.shape[1] == W, "Only height sharding is supported"); - TT_FATAL(this->output_mem_config.is_sharded(), "Error"); - TT_FATAL(this->output_mem_config.memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Error"); + TT_FATAL(this->output_mem_config.is_sharded(), "Output must be sharded for transpose WH"); + TT_FATAL( + this->output_mem_config.memory_layout != TensorMemoryLayout::WIDTH_SHARDED, + "Only height and block sharding is supported for transpose wh"); } else { - TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved inputs cannot output sharded outputs"); + TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved input tensors cannot output sharded outputs"); } } else { if (input_tensor.is_sharded()) { @@ -59,10 +83,11 @@ void Transpose::validate(const std::vector& input_tensors) const { "Only height sharding is supported for transpose hc"); const auto shard_spec = input_tensor.shard_spec().value(); TT_FATAL(shard_spec.shape[1] == W, "Block/Width sharding is not supported"); - TT_FATAL(this->output_mem_config.is_sharded(), "Sharded input can only output sharded tensors"); + TT_FATAL( + this->output_mem_config.is_sharded(), "Sharded input can only output sharded tensors for transpose hc"); TT_FATAL( this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, - "Only height sharding is supported"); + "Only height sharding is supported for the ouput of sharded transpose hc"); } else { TT_FATAL(!this->output_mem_config.is_sharded(), "Interleaved inputs cannot output sharded outputs"); } @@ -84,19 +109,8 @@ void Transpose::validate(const std::vector& input_tensors) const { "HC transpose does not support sharded+tilized inputs"); TT_FATAL( !(input_tensor.is_sharded() && pad_value.has_value() && pad_value.value() != 0.0f), - "Sharded HC transpose does not support non-zero padding"); - } else if (this->dim == TransposeOpDim::CW) { - TT_FATAL(C % TILE_WIDTH == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); - } else if (this->dim == TransposeOpDim::NH) { - TT_FATAL(N % TILE_HEIGHT == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); - } else if (this->dim == TransposeOpDim::NW) { - TT_FATAL(N % TILE_WIDTH == 0, "Error"); - TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::FLOAT32, "Error"); + "Sharded HC transpose does not support non-zero padding {}", + pad_value.value()); } } From d68e2ac8b7a5a1bf37809c5a28c47e934760a838 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Thu, 12 Dec 2024 22:33:20 +0000 Subject: [PATCH 17/20] #0: add attn matmul failure print --- .../attn_matmul/device/attn_matmul_device_operation.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp index 26de812b441..092481bdaee 100644 --- a/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/matmul/attn_matmul/device/attn_matmul_device_operation.cpp @@ -60,7 +60,9 @@ void AttnMatmulDeviceOperation::validate(const std::vector& input_tensor } else { TT_FATAL( ashape[3] == bshape[2], - "Dimension K (A.shape[3] and B.shape[2]) must match for A and B in attn_matmul op"); // A.K == B.K + "Dimension K (A.shape[3]and B.shape[2]) must match for A shape: {} and B shape: {} in attn_matmul op", + ashape, + bshape); // A.K == B.K } } From 9f670c5f1a2603db1df4aaf6f7a433d71377e91d Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Mon, 16 Dec 2024 16:07:03 +0000 Subject: [PATCH 18/20] #0: switch from Span to SmallVector to stop UB --- .../data_movement/permute/permute.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp index f0f33a99040..00d622a15fb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/permute.cpp @@ -27,19 +27,13 @@ inline bool is_on_device(const Tensor& t) { ttnn::Tensor permute_impl( const ttnn::Tensor& a, - const tt::stl::Span& dims, + const ttnn::SmallVector& dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { using ttnn::operations::experimental::auto_format::AutoFormat; - Device* device; // Get the device - if (a.storage_type() != StorageType::DEVICE) { - device = AutoFormat::GetDefaultDevice(); - TT_ASSERT(device != nullptr, "Requires setting default device if no inputs to op are on device"); - } else { - device = a.device(); - } + Device* device = a.device(); if (a.get_shape().rank() > 4) { auto input = a.get_layout() == Layout::TILE @@ -56,9 +50,6 @@ ttnn::Tensor permute_impl( TT_FATAL(dims.size() == 4, "Only 4D tensor are supported for permute."); uint32_t N = dims[0], C = dims[1], H = dims[2], W = dims[3]; - // Convert tensor back to original - auto input_shape = a.get_logical_shape(); - auto formatted_input_tensor = a; // WH and CN should be supported without typecast bool wh = N == 0 && C == 1 && H == 3 && W == 2; @@ -134,13 +125,14 @@ ttnn::Tensor permute_impl( } else { TT_ASSERT(false, "Illegal permute args"); } + // Convert tensor back to original dtype if typecast was performed output = typecast ? ttnn::typecast(output, DataType::BFLOAT8_B) : output; return output; } ttnn::Tensor permute_launch( const ttnn::Tensor& a, - tt::stl::Span dims, + const ttnn::SmallVector& dims, const MemoryConfig& output_mem_config, const std::optional& pad_value) { std::vector output_tensors = {ttnn::Tensor(operation::get_workers_for_op_output({a}))}; @@ -190,7 +182,6 @@ ttnn::Tensor ExecutePermute::invoke( return ttnn::to_memory_config(input_tensor, memory_config.value_or(input_tensor.memory_config())); } - const auto input_layout = input_tensor.get_layout(); auto adjust_order = [](tt::stl::Span dims) { ttnn::SmallVector new_order; TT_FATAL(dims.size() <= 4, "Minimum rank of tensor required is 4"); @@ -206,6 +197,7 @@ ttnn::Tensor ExecutePermute::invoke( auto itensor = (input_tensor.get_logical_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; auto iorder = normalized_dims.size() < 4 ? adjust_order(normalized_dims) : normalized_dims; + const auto input_layout = input_tensor.get_layout(); auto output_tensor = detail::permute_launch(itensor, iorder, memory_config.value_or(input_tensor.memory_config()), pad_value); output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr); From 0401154ddee967d8382a86062e77aef6fe63fcda Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Mon, 16 Dec 2024 20:14:21 +0000 Subject: [PATCH 19/20] #0: skip test uniform due to #16066 --- tests/ttnn/unit_tests/operations/test_uniform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ttnn/unit_tests/operations/test_uniform.py b/tests/ttnn/unit_tests/operations/test_uniform.py index 9c3f05a6a6a..abdfd9aaa31 100644 --- a/tests/ttnn/unit_tests/operations/test_uniform.py +++ b/tests/ttnn/unit_tests/operations/test_uniform.py @@ -94,6 +94,7 @@ def run_uniform(shape, rand_range, dtype, device, compute_kernel_options=None, m ) +@pytest.mark.skip("#16066: Undefined behaviour. It will fail on some runs and pass on others since it's stochastic.") @skip_for_grayskull("Requires wormhole_b0 to run") @pytest.mark.parametrize( "shape", From 01fa552f17b6e473d0ae36332f2620ec3e529da9 Mon Sep 17 00:00:00 2001 From: Saad Jameel Date: Mon, 16 Dec 2024 21:11:22 +0000 Subject: [PATCH 20/20] #0: add comments --- .../dataflow/reader_permute_interleaved_rm_blocked_generic.cpp | 1 + .../dataflow/writer_permute_interleaved_rm_blocked_generic.cpp | 1 + .../dataflow/writer_permute_interleaved_rm_row_invariant.cpp | 1 + 3 files changed, 3 insertions(+) diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp index 15b8b199c3f..f63aaab6d09 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/reader_permute_interleaved_rm_blocked_generic.cpp @@ -27,6 +27,7 @@ void kernel_main() { uint32_t end_block = get_arg_val(2); // Input shape and strides (excluding W dimension and measured in rows, not bytes) + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args uint32_t input_shape[N], src_strides[N]; for (uint32_t i = 3; i < N + 3; i++) { input_shape[i - 3] = get_arg_val(i); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp index 910b23b72d5..5af2edb379f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_blocked_generic.cpp @@ -49,6 +49,7 @@ void kernel_main() { const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = output_tensor_page_size}; // Input shape, permutation, and destination strides + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args uint32_t input_shape[N], perm[N], dest_strides[N]; for (uint32_t i = 3; i < N + 3; i++) { input_shape[i - 3] = get_arg_val(i); diff --git a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp index 46903375ff6..a06e5d56892 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/permute/device/kernels/dataflow/writer_permute_interleaved_rm_row_invariant.cpp @@ -17,6 +17,7 @@ void kernel_main() { const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = page_size}; + // start at runtime arg 3 since address/start_block/end_block make up the first 3 args uint32_t input_shape[N], perm[N], dest_strides[N]; for (uint32_t i = 3; i < N + 3; i++) { input_shape[i - 3] = get_arg_val(i);