From b82e38d4034f224f313551d1265f3bd8bd841aba Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Tue, 19 Nov 2024 09:52:43 +0000 Subject: [PATCH 01/11] Remove restriction of input_nsticks_per_core % w == 0 for height sharded tensor inputs. Signed-off-by: Nilaykumar Patel --- .../unit_tests/operations/test_upsample.py | 25 ++++-- .../writer_upsample_multi_core_sharded.cpp | 75 ++++++++-------- .../upsample_program_factory_multicore.cpp | 86 ++++++++++++++++--- 3 files changed, 135 insertions(+), 51 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index 86047a86581..fa57a486650 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -109,16 +109,30 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): [1, 64, 132, 10], [1, 32, 8, 8], [2, 640, 32, 32], + # some random shapes + [1, 32, 5, 4], + [3, 32, 4, 4], + [5, 64, 5, 5], + [1, 128, 5, 8], + [1, 32, 5, 4], + [7, 64, 128, 17], + [3, 64, 132, 19], ], ) -@pytest.mark.parametrize("scale_h", [2]) -@pytest.mark.parametrize("scale_w", [2]) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize("scale_h", [2, 3]) +@pytest.mark.parametrize("scale_w", [2, 3]) @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK]) def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy): ## input shape is N C H W batch_size, num_channels, height, width = input_shape torch.manual_seed(0) input = torch.rand(input_shape, dtype=torch.bfloat16) + # for i in range(input_shape[0]): + # for j in range(input_shape[1]): + # for k in range(input_shape[2]): + # for l in range(input_shape[3]): + # input[i, j, k, l] = k * width + l + 1 ## golden reference using torch scale_factor = (scale_h, scale_w) @@ -136,15 +150,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate max_grid_size = (device_grid.y, device_grid.x) if shard_strategy == ttnn.ShardStrategy.HEIGHT: ## nsticks per shard should be divisible by in_w - max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1]) + max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1]) nshards = max_nshards while nshards > 0: - if batch_size * height % nshards == 0: + if batch_size * height * width % nshards == 0: break nshards -= 1 ncores = nshards elif shard_strategy == ttnn.ShardStrategy.BLOCK: - max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW + max_nshards_h = min(batch_size * height * width, max_grid_size[0]) ## height along NHW max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C ## find nshards_h along NHW nshards_h = max_nshards_h @@ -353,6 +367,7 @@ def test_bilinear_multi_core( ## compare the results torch_result = torch_result.permute(0, 2, 3, 1) + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999) allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1) logger.info(pcc_msg) diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp index 03530ea7433..91e9a6ff9a2 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp +++ 
b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -3,12 +3,28 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include <stdint.h>
+
 #include "dataflow_api.h"
 
+#define ENABLE_DEBUG_PRINT 0
 
-void kernel_main() {
+#if ENABLE_DEBUG_PRINT == 1
+#include "debug/dprint.h"
+
+inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
+    volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * pagelen;
+    for (uint32_t page = 0; page < npages; ++page) {
+        DPRINT << start + page << ": ";
+        for (uint32_t j = 0; j < pagelen; ++j, ++ptr) {
+            DPRINT << BF16(*ptr) << " ";
+        }
+        DPRINT << ENDL();
+    }
+}
+#endif
+
+void kernel_main() {
     uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
-    uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
+    uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
     uint32_t scale_h = get_arg_val<uint32_t>(2);
     uint32_t scale_w = get_arg_val<uint32_t>(3);
     uint32_t in_w = get_arg_val<uint32_t>(4);
@@ -17,46 +33,37 @@ void kernel_main() {
     constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
     constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
     constexpr uint32_t is_reader = get_compile_time_arg_val(2);
+    constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);
+
+    uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
+    uint32_t writer_nsticks_per_core = in_nsticks_per_core / 2;
+    uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
+    uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
+    uint32_t l1_read_addr = get_read_ptr(in_cb_id);
+    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes;
 
-    uint32_t in_image_row_nbytes = in_w * stick_nbytes;
-    uint32_t out_image_row_nbytes = out_w * stick_nbytes;
-    uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
-    uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
-    uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
-    uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
-    uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
-    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;
+    uint32_t config_l1_addr = get_read_ptr(config_cb_id);
+    volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);
 
+    uint32_t reader_idx = 0;
+    if (!is_reader) {
+        reader_idx = 4 * (scale_h * image_row_begin);
+    }
     cb_reserve_back(out_cb_id, out_w);
 
-    // assuming shard begins with a new row. TODO: generalize?
-    for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
-        uint32_t l1_write_addr_image_row_start = l1_write_addr;
-        for (uint32_t i = 0; i < in_w; ++i) {
+    for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
+        for (uint32_t sh = 0; sh < scale_h; sh++) {
+            uint16_t corex = config_data[reader_idx++];
+            uint16_t corey = config_data[reader_idx++];
+            uint16_t offset = config_data[reader_idx++];
+            reader_idx++;
+            uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
             // replicate stick scale_w times.
-            for (uint32_t sw = 0; sw < scale_w; ++sw) {
-                // replicate stick scale_w times.
-                if constexpr (is_reader) {
-                    uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
-                    noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
-                } else {
-                    uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-                    noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
-                }
+            for (uint32_t sw = 0; sw < scale_w; sw++) {
+                noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes);
                 l1_write_addr += stick_nbytes;
             }
-            l1_read_addr += stick_nbytes;
-        }
-
-        // Duplicate the whole image row in one shot
-        if constexpr (is_reader) {
-            uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
-            noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
-        } else {
-            uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-            noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
         }
-        l1_write_addr += out_image_row_nbytes;
     }
 
     cb_push_back(out_cb_id, out_w);
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 0e12adcb29a..4462244149d 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -3,16 +3,15 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include
+#include
 
-#include "upsample_op.hpp"
-#include "ttnn/operations/math.hpp"
+#include "buffers/buffer_constants.hpp"
+#include "common/core_coord.hpp"
+#include "ttnn/tensor/host_buffer/functions.hpp"
 
 #include "tt_metal/host_api.hpp"
-#include "tt_metal/common/constants.hpp"
-#include "tt_metal/detail/util.hpp"
 #include "tt_metal/common/math.hpp"
-#include "tt_metal/tt_stl/reflection.hpp"
 
 using namespace tt::constants;
 using namespace tt::tt_metal;
@@ -20,6 +19,46 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::upsample {
 using namespace tt;
 
+Tensor create_config_tensor(
+    Device *device,
+    ShardSpec &input_shard_spec,
+    const uint32_t batch_size,
+    const uint32_t in_h,
+    const uint32_t in_w,
+    const uint32_t scale_factor_h,
+    const uint32_t scale_factor_w,
+    const uint32_t ncores) {
+    std::vector<uint16_t> config_vector;
+    uint32_t input_nsticks_per_core = input_shard_spec.shape[0];
+    uint32_t ncores_x = device->compute_with_storage_grid_size().x;
+    uint32_t in_core = 0;
+    uint32_t w = 0;
+    uint32_t curr_stick = 0;
+    auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
+    for (uint32_t b = 0; b < batch_size; b++) {
+        for (uint32_t h = 0; h < in_h; h++) {
+            for (uint32_t w = 0; w < in_w; w++) {
+                if (curr_stick == input_nsticks_per_core) {
+                    curr_stick = 0;
+                    in_core++;
+                    core_coords =
+                        device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
+                }
+                config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0});
+                curr_stick++;
+            }
+            for (uint32_t j = 0; j < scale_factor_h - 1; j++)
+                config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end());
+        }
+    }
+
+    uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
+    Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core});
+    auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
+    Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
+    return config_tensor;
+}
+
 operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) {
     Program program = CreateProgram();
     Device *device = input.device();
@@ -54,7 +93,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
 
     // extra limitation to avoid post upsample step of resharding
     if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
-        TT_FATAL(in_nsticks_per_core % in_w == 0, "Restriction: Input sticks per core {} should be divisible by input width {}. TODO to remove this restriction", in_nsticks_per_core, in_w);
     } else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
         ncores_x = all_cores.ranges().begin()->end_coord.x + 1;
         ncores_nhw = all_cores.ranges().begin()->end_coord.y + 1;
@@ -69,8 +107,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
 
     // TODO: Support non-multiple case
     TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core);
-    TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core);
-    TT_FATAL(input_nsticks_per_core % in_w == 0, "Error");
 
     // CBs
 
@@ -106,12 +142,37 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
     log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x);
     log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core);
 
+    // create config tensor
+    Tensor config_tensor = create_config_tensor(
+        device,
+        shard_spec,
+        input.legacy_shape()[0],
+        input.legacy_shape()[1],
+        in_w,
+        scale_factor_h,
+        scale_factor_w,
+        ncores);
+    auto shard_shape = std::array<uint32_t, 2>({1, (uint32_t)config_tensor.get_shape()[-1]});
+    ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, ShardOrientation::ROW_MAJOR, false);
+    MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, config_shard_spec};
+    auto config_tensor_device = config_tensor.to(device, memory_config);
+    tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());
+
+    tt::DataFormat config_df = tt::DataFormat::RawUInt16;
+    Buffer *config_buffer = config_tensor_device.buffer();
+    uint32_t config_cb_id = tt::CB::c_in2;
+    auto config_cb_config = CircularBufferConfig(config_buffer->size(), {{config_cb_id, config_df}})
+                                .set_page_size(config_cb_id, config_buffer->page_size())
+                                .set_globally_allocated_address(*config_buffer);
+    CBHandle config_cb = CreateCircularBuffer(program, all_cores, config_cb_config);
+
     // Kernels
 
     std::vector<uint32_t> writer_compile_time_args = {
         in_cb_id,
         out_cb_id,
         false,
+        config_cb_id,
     };
     auto writer_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
     auto writer_kernel =
@@ -121,6 +182,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
         in_cb_id,
         out_cb_id,
         true,
+        config_cb_id,
     };
     auto reader_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
     auto reader_kernel =
@@ -133,11 +195,11 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
     uint32_t writer_nargs = 7;
     std::vector<uint32_t> writer_rt_args(writer_nargs);
     writer_rt_args[0] = input_stick_nbytes;
-    writer_rt_args[1] = input_nsticks_per_core / in_w;
+    writer_rt_args[1] = input_nsticks_per_core;
     writer_rt_args[2] = scale_factor_h;
     writer_rt_args[3] = scale_factor_w;
-    writer_rt_args[4] = in_w;
-    writer_rt_args[5] = out_w;
+    writer_rt_args[4] = input_nsticks_per_core;
+    writer_rt_args[5] = output_nsticks_per_core / 2; // half of the outputs are processed by each core
     writer_rt_args[6] = 0; // set for each core below
 
     uint32_t start_input_stick_id = 0;
@@ -163,7 +225,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
         TT_THROW("Unsupported memory layout");
     }
 
-    auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb](
+    auto override_runtime_args_callback = [writer_kernel, cb_src0, config_cb, out_cb](
                                               const void* operation,
                                               Program &program,
                                               const std::vector<Tensor>& input_tensors,

From 8890e704ac66b487f923e48498832590ebdcad9f Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Tue, 26 Nov 2024 11:03:27 +0000
Subject: [PATCH 02/11] Add support for block sharding for upsample.

Signed-off-by: Nilaykumar Patel
---
 .../writer_upsample_multi_core_sharded.cpp    |  1 -
 .../upsample_program_factory_multicore.cpp    | 55 +++++++++++++++----
 2 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
index 91e9a6ff9a2..0fbee10ad5a 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -68,6 +68,5 @@ void kernel_main() {
 
     cb_push_back(out_cb_id, out_w);
 
-    noc_async_write_barrier();
     noc_async_read_barrier();
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 4462244149d..f0bcf187d85 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include
+#include
 #include
 
 #include "buffers/buffer_constants.hpp"
@@ -27,14 +28,21 @@ Tensor create_config_tensor(
     const uint32_t in_w,
     const uint32_t scale_factor_h,
     const uint32_t scale_factor_w,
-    const uint32_t ncores) {
+    TensorMemoryLayout shard_scheme,
+    uint32_t ncores_nhw,
+    uint32_t ncores_x) {
     std::vector<uint16_t> config_vector;
     uint32_t input_nsticks_per_core = input_shard_spec.shape[0];
-    uint32_t ncores_x = device->compute_with_storage_grid_size().x;
     uint32_t in_core = 0;
     uint32_t w = 0;
     uint32_t curr_stick = 0;
-    auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
+    if(shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
+        ncores_x = 1;
+        ncores_nhw = 1;
+    }
+    uint32_t physical_core_x = device->compute_with_storage_grid_size().x;
+
+    auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % physical_core_x, in_core / physical_core_x));
     for (uint32_t b = 0; b < batch_size; b++) {
         for (uint32_t h = 0; h < in_h; h++) {
             for (uint32_t w = 0; w < in_w; w++) {
@@ -42,7 +50,7 @@ Tensor create_config_tensor(
                     curr_stick = 0;
                     in_core++;
                     core_coords =
-                        device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
+                        device->worker_core_from_logical_core(CoreCoord(0, in_core));
                 }
                 config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0});
                 curr_stick++;
@@ -51,10 +59,31 @@ Tensor create_config_tensor(
                 config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end());
         }
     }
+    // Copy for y direction
+    std::vector<uint16_t> temp_config_vector;
+    /*auto prev_idx = 0;*/
+    /*auto idx = 0;*/
+    /*for(uint32_t i = 0; i < ncores_nhw; i++) {*/
+    /*    idx = 4 * (i+1) * input_nsticks_per_core * scale_factor_h;*/
+    /*    for(uint32_t j = 0; j < ncores_x; j++) {*/
+    /*        temp_config_vector.insert(temp_config_vector.end(), config_vector.begin() + prev_idx, config_vector.begin() + idx);*/
+    /*    }*/
+    /*    prev_idx = idx;*/
+    /*}*/
+    for(uint32_t i = 0; i < ncores_x; i++) {
+        /*TODO: Change to take core x into consideration.*/
+        temp_config_vector.insert(temp_config_vector.end(), config_vector.begin(), config_vector.end());
+    }
+
+    using namespace std;
+    uint32_t core = 0;
+    for(auto i = 0; i < temp_config_vector.size(); i+=4) {
+        cout << temp_config_vector[i] << " " << temp_config_vector[i+1] << " " << temp_config_vector[i+2] << " " << temp_config_vector[i+3] << endl;
+    }
 
     uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
-    Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core});
-    auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
+    Shape config_shape = Shape({temp_config_vector.size() / elems_per_core, elems_per_core});
+    auto config_buffer = owned_buffer::create<uint16_t>(std::move(temp_config_vector));
     Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
     return config_tensor;
 }
@@ -151,17 +180,23 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
         in_w,
         scale_factor_h,
         scale_factor_w,
-        ncores);
+        input.memory_config().memory_layout,
+        ncores_nhw,
+        ncores_x);
+    config_tensor.print();
     auto shard_shape = std::array<uint32_t, 2>({1, (uint32_t)config_tensor.get_shape()[-1]});
-    ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, ShardOrientation::ROW_MAJOR, false);
-    MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, config_shard_spec};
+    auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ? (shard_spec.orientation == ShardOrientation::COL_MAJOR ? ShardOrientation::ROW_MAJOR : ShardOrientation::COL_MAJOR) : ShardOrientation::ROW_MAJOR;
+    ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false);
+    MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec};
     auto config_tensor_device = config_tensor.to(device, memory_config);
+    config_tensor_device.print();
     tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());
 
     tt::DataFormat config_df = tt::DataFormat::RawUInt16;
     Buffer *config_buffer = config_tensor_device.buffer();
+    auto config_buffer_page_size = config_buffer->page_size();
     uint32_t config_cb_id = tt::CB::c_in2;
-    auto config_cb_config = CircularBufferConfig(config_buffer->size(), {{config_cb_id, config_df}})
+    auto config_cb_config = CircularBufferConfig(config_buffer_page_size, {{config_cb_id, config_df}})
                                 .set_page_size(config_cb_id, config_buffer->page_size())
                                 .set_globally_allocated_address(*config_buffer);
     CBHandle config_cb = CreateCircularBuffer(program, all_cores, config_cb_config);

From e39f225eafee035ee1e81da8315aa6500466fe1c Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Tue, 10 Dec 2024 16:04:04 +0000
Subject: [PATCH 03/11] Add support for block sharding. ToDo: commonize code.

Signed-off-by: Nilaykumar Patel
---
 .../unit_tests/operations/test_upsample.py    |   6 +-
 .../upsample_program_factory_multicore.cpp    | 119 +++++++++++-------
 2 files changed, 75 insertions(+), 50 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py
index fa57a486650..9afe7b7bd49 100644
--- a/tests/ttnn/unit_tests/operations/test_upsample.py
+++ b/tests/ttnn/unit_tests/operations/test_upsample.py
@@ -115,8 +115,8 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
         [5, 64, 5, 5],
         [1, 128, 5, 8],
         [1, 32, 5, 4],
-        [7, 64, 128, 17],
-        [3, 64, 132, 19],
+        [1, 64, 128, 17],
+        [1, 64, 132, 19],
     ],
 )
 @pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@@ -132,7 +132,7 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
     #     for j in range(input_shape[1]):
     #         for k in range(input_shape[2]):
     #             for l in range(input_shape[3]):
-    #                 input[i, j, k, l] = k * width + l + 1
+    #                 input[i, j, k, l] = (k * width + l + 1)
 
     ## golden reference using torch
     scale_factor = (scale_h, scale_w)
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index f0bcf187d85..43d1ab78d59 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -7,6 +7,7 @@
 #include
 
 #include "buffers/buffer_constants.hpp"
+#include "common/assert.hpp"
 #include "common/core_coord.hpp"
 #include "ttnn/tensor/host_buffer/functions.hpp"
 
@@ -20,29 +21,20 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::upsample {
 using namespace tt;
 
-Tensor create_config_tensor(
+Tensor create_config_tensor_height_sharded(
     Device *device,
-    ShardSpec &input_shard_spec,
+    uint32_t input_nsticks_per_core,
     const uint32_t batch_size,
     const uint32_t in_h,
     const uint32_t in_w,
     const uint32_t scale_factor_h,
-    const uint32_t scale_factor_w,
-    TensorMemoryLayout shard_scheme,
-    uint32_t ncores_nhw,
-    uint32_t ncores_x) {
+    const uint32_t scale_factor_w) {
     std::vector<uint16_t> config_vector;
-    uint32_t input_nsticks_per_core = input_shard_spec.shape[0];
+    uint32_t ncores_x = device->compute_with_storage_grid_size().x;
     uint32_t in_core = 0;
     uint32_t w = 0;
     uint32_t curr_stick = 0;
-    if(shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED) {
-        ncores_x = 1;
-        ncores_nhw = 1;
-    }
-    uint32_t physical_core_x = device->compute_with_storage_grid_size().x;
-
-    auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % physical_core_x, in_core / physical_core_x));
+    auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
     for (uint32_t b = 0; b < batch_size; b++) {
         for (uint32_t h = 0; h < in_h; h++) {
             for (uint32_t w = 0; w < in_w; w++) {
@@ -50,7 +42,7 @@ Tensor create_config_tensor(
                     curr_stick = 0;
                     in_core++;
                     core_coords =
-                        device->worker_core_from_logical_core(CoreCoord(0, in_core));
+                        device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x));
                 }
                 config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0});
                 curr_stick++;
@@ -59,27 +51,50 @@ Tensor create_config_tensor(
                 config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end());
         }
     }
-    // Copy for y direction
-    std::vector<uint16_t> temp_config_vector;
-    /*auto prev_idx = 0;*/
-    /*auto idx = 0;*/
-    /*for(uint32_t i = 0; i < ncores_nhw; i++) {*/
-    /*    idx = 4 * (i+1) * input_nsticks_per_core * scale_factor_h;*/
-    /*    for(uint32_t j = 0; j < ncores_x; j++) {*/
-    /*        temp_config_vector.insert(temp_config_vector.end(), config_vector.begin() + prev_idx, config_vector.begin() + idx);*/
-    /*    }*/
-    /*    prev_idx = idx;*/
-    /*}*/
-    for(uint32_t i = 0; i < ncores_x; i++) {
-        /*TODO: Change to take core x into consideration.*/
-        temp_config_vector.insert(temp_config_vector.end(), config_vector.begin(), config_vector.end());
-    }
-
-    using namespace std;
-    uint32_t core = 0;
-    for(auto i = 0; i < temp_config_vector.size(); i+=4) {
-        cout << temp_config_vector[i] << " " << temp_config_vector[i+1] << " " << temp_config_vector[i+2] << " " << temp_config_vector[i+3] << endl;
-    }
 
     uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
-    Shape config_shape = Shape({temp_config_vector.size() / elems_per_core, elems_per_core});
-    auto config_buffer = owned_buffer::create<uint16_t>(std::move(temp_config_vector));
+    Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core});
+    auto config_buffer = owned_buffer::create<uint16_t>(std::move(config_vector));
     Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
     return config_tensor;
 }
 
+Tensor create_config_tensor_block_sharded(
+    Device *device,
+    uint32_t input_nsticks_per_core,
+    const uint32_t batch_size,
+    const uint32_t in_h,
+    const uint32_t in_w,
+    const uint32_t scale_factor_h,
+    const uint32_t scale_factor_w,
+    uint32_t ncores_x) {
+    std::vector<uint16_t> config_vector;
+    uint32_t in_core = 0;
+    uint32_t w = 0;
+    uint32_t curr_stick = 0;
+
+    CoreCoord core_coords;
+    for (uint32_t b = 0; b < batch_size; b++) {
+        for (uint32_t h = 0; h < in_h; h++) {
+            for (uint32_t w = 0; w < in_w; w++) {
+                if (curr_stick == input_nsticks_per_core) {
+                    curr_stick = 0;
+                    in_core++;
+                }
+                config_vector.insert(config_vector.end(), {in_core, curr_stick});
+                curr_stick++;
+            }
+            for (uint32_t j = 0; j < scale_factor_h - 1; j++)
+                config_vector.insert(config_vector.end(), config_vector.end() - (2 * in_w), config_vector.end());
+        }
+    }
+    std::vector<uint16_t> temp_config_vector;
+    for(uint32_t i = 0; i < ncores_x; i++) {
+        for(uint32_t j = 0; j < config_vector.size(); j+=2) {
+            core_coords = device->worker_core_from_logical_core(CoreCoord(i, config_vector[j]));
+
temp_config_vector.insert(temp_config_vector.end(), {core_coords.x, core_coords.y, config_vector[j+1], 0}); + } } uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; Shape config_shape = Shape({temp_config_vector.size() / elems_per_core, elems_per_core}); @@ -172,24 +187,34 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core); // create config tensor - Tensor config_tensor = create_config_tensor( - device, - shard_spec, - input.legacy_shape()[0], - input.legacy_shape()[1], - in_w, - scale_factor_h, - scale_factor_w, - input.memory_config().memory_layout, - ncores_nhw, - ncores_x); - config_tensor.print(); + Tensor config_tensor; + if(input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + config_tensor = create_config_tensor_block_sharded( + device, + shard_spec.shape[0], + input.legacy_shape()[0], + input.legacy_shape()[1], + in_w, + scale_factor_h, + scale_factor_w, + ncores_x); + } else if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + config_tensor = create_config_tensor_height_sharded( + device, + shard_spec.shape[0], + input.legacy_shape()[0], + input.legacy_shape()[1], + in_w, + scale_factor_h, + scale_factor_w); + } else { + TT_THROW("Unsupported sharding layout"); + } auto shard_shape = std::array({1, (uint32_t)config_tensor.get_shape()[-1]}); auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ? (shard_spec.orientation == ShardOrientation::COL_MAJOR ? ShardOrientation::ROW_MAJOR : ShardOrientation::COL_MAJOR) : ShardOrientation::ROW_MAJOR; ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false); MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec}; auto config_tensor_device = config_tensor.to(device, memory_config); - config_tensor_device.print(); tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer()); tt::DataFormat config_df = tt::DataFormat::RawUInt16; From cbcf0424079a11f46828733b590aaee9e39ce2f5 Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Wed, 11 Dec 2024 12:47:20 +0000 Subject: [PATCH 04/11] Commonize code for block and height sharding. 
Signed-off-by: Nilaykumar Patel --- .../upsample_program_factory_multicore.cpp | 116 +++++++----------- 1 file changed, 42 insertions(+), 74 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index 43d1ab78d59..0992b1ba1ad 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -21,45 +21,7 @@ using namespace tt::tt_metal; namespace ttnn::operations::upsample { using namespace tt; -Tensor create_config_tensor_height_sharded( - Device *device, - uint32_t input_nsticks_per_core, - const uint32_t batch_size, - const uint32_t in_h, - const uint32_t in_w, - const uint32_t scale_factor_h, - const uint32_t scale_factor_w) { - std::vector config_vector; - uint32_t ncores_x = device->compute_with_storage_grid_size().x; - uint32_t in_core = 0; - uint32_t w = 0; - uint32_t curr_stick = 0; - auto core_coords = device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x)); - for (uint32_t b = 0; b < batch_size; b++) { - for (uint32_t h = 0; h < in_h; h++) { - for (uint32_t w = 0; w < in_w; w++) { - if (curr_stick == input_nsticks_per_core) { - curr_stick = 0; - in_core++; - core_coords = - device->worker_core_from_logical_core(CoreCoord(in_core % ncores_x, in_core / ncores_x)); - } - config_vector.insert(config_vector.end(), {core_coords.x, core_coords.y, curr_stick, 0}); - curr_stick++; - } - for (uint32_t j = 0; j < scale_factor_h - 1; j++) - config_vector.insert(config_vector.end(), config_vector.end() - (4 * in_w), config_vector.end()); - } - } - - uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; - Shape config_shape = Shape({config_vector.size() / elems_per_core, elems_per_core}); - auto config_buffer = owned_buffer::create(std::move(config_vector)); - Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); - return config_tensor; -} - -Tensor create_config_tensor_block_sharded( +static Tensor create_config_tensor_block_sharded( Device *device, uint32_t input_nsticks_per_core, const uint32_t batch_size, @@ -67,42 +29,55 @@ Tensor create_config_tensor_block_sharded( const uint32_t in_w, const uint32_t scale_factor_h, const uint32_t scale_factor_w, - uint32_t ncores_x) { + uint32_t ncores_x, + bool is_height_sharded) { std::vector config_vector; - uint32_t in_core = 0; - uint32_t w = 0; - uint32_t curr_stick = 0; + uint16_t in_core = 0, curr_stick = 0; + uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; - CoreCoord core_coords; - for (uint32_t b = 0; b < batch_size; b++) { - for (uint32_t h = 0; h < in_h; h++) { - for (uint32_t w = 0; w < in_w; w++) { - if (curr_stick == input_nsticks_per_core) { - curr_stick = 0; - in_core++; - } - config_vector.insert(config_vector.end(), {in_core, curr_stick}); - curr_stick++; + // Create map of core and respective offsets in input + for (uint32_t b = 0; b < batch_size; ++b) { + for (uint32_t h = 0; h < in_h; ++h) { + for (uint32_t w = 0; w < in_w; ++w, ++curr_stick) { + if (curr_stick == input_nsticks_per_core) curr_stick = 0, ++in_core; + config_vector.push_back(in_core); + config_vector.push_back(curr_stick); } - for (uint32_t j = 0; j < scale_factor_h - 1; j++) - config_vector.insert(config_vector.end(), config_vector.end() - (2 * in_w), config_vector.end()); 
+ size_t row_size = 2 * in_w, initial_size = config_vector.size(); + for (uint32_t j = 1; j < scale_factor_h; ++j) + config_vector.insert(config_vector.end(), config_vector.end() - row_size, config_vector.end()); } } + std::vector temp_config_vector; - for(uint32_t i = 0; i < ncores_x; i++) { - for(uint32_t j = 0; j < config_vector.size(); j+=2) { - core_coords = device->worker_core_from_logical_core(CoreCoord(i, config_vector[j])); - temp_config_vector.insert(temp_config_vector.end(), {core_coords.x, core_coords.y, config_vector[j+1], 0}); + // Based on core calculate physical dimentions of cores + CoreCoord core_coords; + if (is_height_sharded) { + for (size_t j = 0; j < config_vector.size(); j += 2) { + core_coords = device->worker_core_from_logical_core(CoreCoord(config_vector[j] % ncores_x, config_vector[j] / ncores_x)); + temp_config_vector.push_back(core_coords.x); + temp_config_vector.push_back(core_coords.y); + temp_config_vector.push_back(config_vector[j + 1]); + temp_config_vector.push_back(0); + } + } else { + for (uint32_t i = 0; i < ncores_x; i++) { + for (size_t j = 0; j < config_vector.size(); j += 2) { + core_coords = device->worker_core_from_logical_core(CoreCoord(i, config_vector[j])); + temp_config_vector.push_back(core_coords.x); + temp_config_vector.push_back(core_coords.y); + temp_config_vector.push_back(config_vector[j + 1]); + temp_config_vector.push_back(0); + } } } - uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; - Shape config_shape = Shape({temp_config_vector.size() / elems_per_core, elems_per_core}); + Shape config_shape({temp_config_vector.size() / elems_per_core, elems_per_core}); auto config_buffer = owned_buffer::create(std::move(temp_config_vector)); - Tensor config_tensor = Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); - return config_tensor; + return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); } + operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program = CreateProgram(); Device *device = input.device(); @@ -151,6 +126,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // TODO: Support non-multiple case TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core); + TT_FATAL(shard_spec.orientation == ShardOrientation::ROW_MAJOR, "Input tensor is expected to have ROW_MAJOR shard orientation, got {}", shard_spec.orientation); // CBs @@ -188,7 +164,7 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // create config tensor Tensor config_tensor; - if(input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) { + if((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) || (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) { config_tensor = create_config_tensor_block_sharded( device, shard_spec.shape[0], @@ -197,16 +173,8 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& in_w, scale_factor_h, scale_factor_w, - ncores_x); - } else if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - config_tensor = create_config_tensor_height_sharded( - device, - shard_spec.shape[0], - input.legacy_shape()[0], - input.legacy_shape()[1], - in_w, - 
scale_factor_h, - scale_factor_w); + ncores_x, + input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); } else { TT_THROW("Unsupported sharding layout"); } From 67eba82987193ceec1cb4e01eb0df20c749f4631 Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Thu, 12 Dec 2024 11:40:06 +0000 Subject: [PATCH 05/11] Clean up and add support for Column major shard orientation. Signed-off-by: Nilaykumar Patel --- .../unit_tests/operations/test_upsample.py | 10 +- .../upsample_program_factory_multicore.cpp | 133 +++++++++++------- 2 files changed, 91 insertions(+), 52 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index 9afe7b7bd49..4398aaaa73a 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn import ttnn -from models.utility_functions import skip_for_grayskull, skip_for_blackhole +from models.utility_functions import skip_for_grayskull, skip_for_blackhole, is_grayskull from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout @@ -119,11 +119,14 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w): [1, 64, 132, 19], ], ) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) @pytest.mark.parametrize("scale_h", [2, 3]) @pytest.mark.parametrize("scale_w", [2, 3]) @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK]) -def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy): +@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR]) +def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy, shard_orientation): + if is_grayskull() and (scale_h > 2 or scale_w > 2): + pytest.skip("Skipping test because it won't fit in L1!") + ## input shape is N C H W batch_size, num_channels, height, width = input_shape torch.manual_seed(0) @@ -191,7 +194,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate # ) shard_grid = get_shard_grid_from_num_cores(device, ncores) - shard_orientation = ttnn.ShardOrientation.ROW_MAJOR if shard_strategy == ttnn.ShardStrategy.BLOCK: tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index 0992b1ba1ad..3d5c5b1c734 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -2,10 +2,9 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include #include +#include "buffers/buffer.hpp" #include "buffers/buffer_constants.hpp" #include "common/assert.hpp" #include "common/core_coord.hpp" @@ -21,63 +20,98 @@ using namespace tt::tt_metal; namespace ttnn::operations::upsample { using namespace tt; -static Tensor create_config_tensor_block_sharded( - Device *device, - uint32_t input_nsticks_per_core, +static Tensor create_config_tensor( + Device* device, + ShardSpec shard_spec, const uint32_t batch_size, const uint32_t in_h, const uint32_t in_w, const uint32_t scale_factor_h, const uint32_t scale_factor_w, - uint32_t ncores_x, - bool is_height_sharded) { - std::vector 
config_vector; + const uint32_t ncores_x, + const bool is_height_sharded, + const bool is_col_major) { uint16_t in_core = 0, curr_stick = 0; - uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; + const uint32_t input_nsticks_per_core = shard_spec.shape[0]; + + std::vector> core_range; + auto ranges = shard_spec.grid.ranges(); + // in case of height sharding and shards arranged in column major order, get cores where shard are placed. + if (is_col_major && is_height_sharded) { + for (auto i = 0; i < ranges.size(); i++) { + auto range = ranges[i]; + for (auto x = range.start_coord.x; x <= range.end_coord.x; x++) { + for (auto y = range.start_coord.y; y <= range.end_coord.y; y++) { + core_range.push_back({x, y}); + } + } + } + } + std::vector logical_core_to_stick_map; + size_t logical_core_to_stick_map_entry_size = 3; + size_t row_size = logical_core_to_stick_map_entry_size * in_w; // Create map of core and respective offsets in input for (uint32_t b = 0; b < batch_size; ++b) { for (uint32_t h = 0; h < in_h; ++h) { for (uint32_t w = 0; w < in_w; ++w, ++curr_stick) { - if (curr_stick == input_nsticks_per_core) curr_stick = 0, ++in_core; - config_vector.push_back(in_core); - config_vector.push_back(curr_stick); + if (curr_stick == input_nsticks_per_core) { + curr_stick = 0, ++in_core; + } + if (is_height_sharded && is_col_major) { + logical_core_to_stick_map.push_back(core_range[in_core][0]); + logical_core_to_stick_map.push_back(core_range[in_core][1]); + } else { + logical_core_to_stick_map.push_back(in_core); + logical_core_to_stick_map.push_back(0); + } + logical_core_to_stick_map.push_back(curr_stick); + } + for (uint32_t j = 1; j < scale_factor_h; ++j) { + logical_core_to_stick_map.insert( + logical_core_to_stick_map.end(), + logical_core_to_stick_map.end() - row_size, + logical_core_to_stick_map.end()); } - size_t row_size = 2 * in_w, initial_size = config_vector.size(); - for (uint32_t j = 1; j < scale_factor_h; ++j) - config_vector.insert(config_vector.end(), config_vector.end() - row_size, config_vector.end()); } } - std::vector temp_config_vector; + std::vector config_vector; - // Based on core calculate physical dimentions of cores + // Based on core calculate physical location of cores CoreCoord core_coords; if (is_height_sharded) { - for (size_t j = 0; j < config_vector.size(); j += 2) { - core_coords = device->worker_core_from_logical_core(CoreCoord(config_vector[j] % ncores_x, config_vector[j] / ncores_x)); - temp_config_vector.push_back(core_coords.x); - temp_config_vector.push_back(core_coords.y); - temp_config_vector.push_back(config_vector[j + 1]); - temp_config_vector.push_back(0); + for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) { + CoreCoord core_coords; + if (is_col_major) { + core_coords = device->worker_core_from_logical_core( + CoreCoord(logical_core_to_stick_map[j], logical_core_to_stick_map[j + 1])); + } else { + core_coords = device->worker_core_from_logical_core( + CoreCoord(logical_core_to_stick_map[j] % ncores_x, logical_core_to_stick_map[j] / ncores_x)); + } + config_vector.push_back(core_coords.x); + config_vector.push_back(core_coords.y); + config_vector.push_back(logical_core_to_stick_map[j + 2]); + config_vector.push_back(0); } } else { - for (uint32_t i = 0; i < ncores_x; i++) { - for (size_t j = 0; j < config_vector.size(); j += 2) { - core_coords = device->worker_core_from_logical_core(CoreCoord(i, config_vector[j])); - temp_config_vector.push_back(core_coords.x); - 
temp_config_vector.push_back(core_coords.y); - temp_config_vector.push_back(config_vector[j + 1]); - temp_config_vector.push_back(0); + for (size_t i = 0; i < ncores_x; i++) { + for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) { + core_coords = device->worker_core_from_logical_core(CoreCoord(i, logical_core_to_stick_map[j])); + config_vector.push_back(core_coords.x); + config_vector.push_back(core_coords.y); + config_vector.push_back(logical_core_to_stick_map[j + 2]); + config_vector.push_back(0); } } } - Shape config_shape({temp_config_vector.size() / elems_per_core, elems_per_core}); - auto config_buffer = owned_buffer::create(std::move(temp_config_vector)); + uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core; + Shape config_shape({config_vector.size() / elems_per_core, elems_per_core}); + auto config_buffer = owned_buffer::create(std::move(config_vector)); return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); } - operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program = CreateProgram(); Device *device = input.device(); @@ -126,7 +160,6 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // TODO: Support non-multiple case TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core); - TT_FATAL(shard_spec.orientation == ShardOrientation::ROW_MAJOR, "Input tensor is expected to have ROW_MAJOR shard orientation, got {}", shard_spec.orientation); // CBs @@ -164,29 +197,33 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& // create config tensor Tensor config_tensor; - if((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) || (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) { - config_tensor = create_config_tensor_block_sharded( - device, - shard_spec.shape[0], - input.legacy_shape()[0], - input.legacy_shape()[1], - in_w, - scale_factor_h, - scale_factor_w, - ncores_x, - input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED); + if ((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) || + (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) { + config_tensor = create_config_tensor( + device, + shard_spec, + input.legacy_shape()[0], + input.legacy_shape()[1], + in_w, + scale_factor_h, + scale_factor_w, + ncores_x, + input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + shard_spec.orientation == ShardOrientation::COL_MAJOR); } else { TT_THROW("Unsupported sharding layout"); } auto shard_shape = std::array({1, (uint32_t)config_tensor.get_shape()[-1]}); - auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED ? (shard_spec.orientation == ShardOrientation::COL_MAJOR ? ShardOrientation::ROW_MAJOR : ShardOrientation::COL_MAJOR) : ShardOrientation::ROW_MAJOR; + auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED + ? 
ShardOrientation::COL_MAJOR + : shard_spec.orientation; ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false); - MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec}; + MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1, config_shard_spec}; auto config_tensor_device = config_tensor.to(device, memory_config); tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer()); tt::DataFormat config_df = tt::DataFormat::RawUInt16; - Buffer *config_buffer = config_tensor_device.buffer(); + Buffer* config_buffer = config_tensor_device.buffer(); auto config_buffer_page_size = config_buffer->page_size(); uint32_t config_cb_id = tt::CB::c_in2; auto config_cb_config = CircularBufferConfig(config_buffer_page_size, {{config_cb_id, config_df}}) From 6fcb4712da552c91565c3a7896de86ba66466a20 Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Fri, 13 Dec 2024 08:41:47 +0000 Subject: [PATCH 06/11] Align code Signed-off-by: Nilaykumar Patel --- .../upsample_program_factory_multicore.cpp | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index 3d5c5b1c734..3165c85e97c 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -13,7 +13,6 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/common/math.hpp" - using namespace tt::constants; using namespace tt::tt_metal; @@ -112,9 +111,10 @@ static Tensor create_config_tensor( return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR); } -operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { +operation::ProgramWithCallbacks upsample_multi_core( + const Tensor& input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program = CreateProgram(); - Device *device = input.device(); + Device* device = input.device(); tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); @@ -139,27 +139,36 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& uint32_t ncores_nhw = ncores; auto out_shard_spec = output.shard_spec().value(); - TT_FATAL(out_shard_spec.num_cores() == ncores, "Output tensor should have same number of cores {} as input tensor {}", out_shard_spec.num_cores(), ncores); + TT_FATAL( + out_shard_spec.num_cores() == ncores, + "Output tensor should have same number of cores {} as input tensor {}", + out_shard_spec.num_cores(), + ncores); uint32_t in_nsticks_per_core = shard_spec.shape[0]; uint32_t out_nsticks_per_core = in_nsticks_per_core * scale_factor_h * scale_factor_w; + if (input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { + TT_THROW("Unsupported sharding layout"); + } + // extra limitation to avoid post upsample step of resharding - if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { - } else if (input.memory_config().memory_layout == 
TensorMemoryLayout::BLOCK_SHARDED) {
+    if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
         ncores_x = all_cores.ranges().begin()->end_coord.x + 1;
         ncores_nhw = all_cores.ranges().begin()->end_coord.y + 1;
         input_stick_nbytes = input_stick_nbytes / ncores_x;
         output_stick_nbytes = output_stick_nbytes / ncores_x;
-    } else {
-        TT_THROW("Unsupported sharding layout");
     }
 
     uint32_t input_nsticks_per_core = div_up(input_nsticks, ncores_nhw);
     uint32_t output_nsticks_per_core = div_up(output_nsticks, ncores_nhw);
 
     // TODO: Support non-multiple case
-    TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core);
+    TT_FATAL(
+        in_nsticks_per_core == input_nsticks_per_core,
+        "Input sticks per shard {} should be same as input sticks per core {}",
+        in_nsticks_per_core,
+        input_nsticks_per_core);
 
     // CBs
 
@@ -170,11 +179,10 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
     uint32_t aligned_input_stick_nbytes = round_up_to_mul32(input_stick_nbytes);
     uint32_t in_cb_pagesize = aligned_input_stick_nbytes;
     uint32_t in_cb_npages = input_nsticks_per_core * buffering_factor;
-    CircularBufferConfig cb_src0_config = CircularBufferConfig(
-        in_cb_pagesize * in_cb_npages,
-        {{in_cb_id, input_cb_data_format}})
-        .set_page_size(in_cb_id, in_cb_pagesize)
-        .set_globally_allocated_address(*input.buffer());
+    CircularBufferConfig cb_src0_config =
+        CircularBufferConfig(in_cb_pagesize * in_cb_npages, {{in_cb_id, input_cb_data_format}})
+            .set_page_size(in_cb_id, in_cb_pagesize)
+            .set_globally_allocated_address(*input.buffer());
     auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config);
 
     // output sharded CB with upsampled data
@@ -182,18 +190,21 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
     uint32_t aligned_output_stick_nbytes = round_up_to_mul32(output_stick_nbytes);
     uint32_t out_cb_pagesize = aligned_output_stick_nbytes;
     uint32_t out_cb_npages = output_nsticks_per_core * buffering_factor;
-    CircularBufferConfig out_cb_config = CircularBufferConfig(
-        out_cb_pagesize * out_cb_npages,
-        {{out_cb_id, output_cb_data_format}})
-        .set_page_size(out_cb_id, out_cb_pagesize)
-        .set_globally_allocated_address(*output.buffer());
+    CircularBufferConfig out_cb_config =
+        CircularBufferConfig(out_cb_pagesize * out_cb_npages, {{out_cb_id, output_cb_data_format}})
+            .set_page_size(out_cb_id, out_cb_pagesize)
+            .set_globally_allocated_address(*output.buffer());
     auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, out_cb_config);
 
     log_debug(LogOp, "input_cb: {}, npages: {}, pagesize: {}", in_cb_id, in_cb_npages, in_cb_pagesize);
     log_debug(LogOp, "output_cb: {}, npages: {}, pagesize: {}", out_cb_id, out_cb_npages, out_cb_pagesize);
     log_debug(LogOp, "input_stick_nbytes: {}, output_stick_nbytes: {}", input_stick_nbytes, output_stick_nbytes);
     log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x);
-    log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core);
+    log_debug(
+        LogOp,
+        "input_nsticks_per_core: {}, output_nsticks_per_core: {}",
+        input_nsticks_per_core,
+        output_nsticks_per_core);
 
     // create config tensor
     Tensor config_tensor;
@@ -239,7 +250,8 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
         false,
         config_cb_id,
     };
-    auto writer_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
+    auto writer_kernel_fname = std::string(
+        "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
     auto writer_kernel =
         CreateKernel(program, writer_kernel_fname, all_cores, WriterDataMovementConfig(writer_compile_time_args));
 
@@ -249,7 +261,8 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
         true,
         config_cb_id,
     };
-    auto reader_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
+    auto reader_kernel_fname = std::string(
+        "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
     auto reader_kernel =
         CreateKernel(program, reader_kernel_fname, all_cores, ReaderDataMovementConfig(reader_compile_time_args));
 
@@ -264,14 +277,14 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor&
     writer_rt_args[2] = scale_factor_h;
     writer_rt_args[3] = scale_factor_w;
     writer_rt_args[4] = input_nsticks_per_core;
-    writer_rt_args[5] = output_nsticks_per_core / 2; // half of the outputs are processed by each core
-    writer_rt_args[6] = 0; // set for each core below
+    writer_rt_args[5] = output_nsticks_per_core / 2;  // half of the outputs are processed by each core
+    writer_rt_args[6] = 0;                            // set for each core below
     uint32_t start_input_stick_id = 0;
 
     if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
         for (int32_t core = 0; core < ncores_nhw; ++core) {
             for (int32_t core_x = 0; core_x < ncores_x; ++core_x) {
-                CoreCoord core_coord(core_x, core); // logical
+                CoreCoord core_coord(core_x, core);  // logical
                 writer_rt_args[6] = start_input_stick_id;
                 SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args);
                 SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args);
@@ -280,7 +293,7 @@
         }
     } else if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
         for (int32_t core = 0; core < ncores_nhw; ++core) {
-            CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical
+            CoreCoord core_coord(core % ncores_x, core / ncores_x);  // logical
             writer_rt_args[6] = start_input_stick_id;
             SetRuntimeArgs(program, writer_kernel, core_coord, writer_rt_args);
             SetRuntimeArgs(program, reader_kernel, core_coord, writer_rt_args);
@@ -291,13 +304,11 @@
     }
 
     auto override_runtime_args_callback = [writer_kernel, cb_src0, config_cb, out_cb](
-        const void* operation,
-        Program &program,
-        const std::vector<Tensor>& input_tensors,
-        const std::vector<std::optional<const Tensor>>&,
-        const std::vector<Tensor>& output_tensors
-    ) {
-
+                                              const void* operation,
+                                              Program& program,
+                                              const std::vector<Tensor>& input_tensors,
+                                              const std::vector<std::optional<const Tensor>>&,
+                                              const std::vector<Tensor>& output_tensors) {
         auto src_buffer = input_tensors.at(0).buffer();
         auto dst_buffer = output_tensors.at(0).buffer();
 
@@ -305,7 +316,7 @@
         UpdateDynamicCircularBufferAddress(program, out_cb, *dst_buffer);
     };
 
-    return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback};
+    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
 }
 
 }  // namespace ttnn::operations::upsample

From fdb514b3578f3dfe54bf1c7930554313e8b33708 Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Fri, 13 Dec 2024 11:52:14 +0000
Subject: [PATCH 07/11] Make pipeline work.

Signed-off-by: Nilaykumar Patel

---
 .../ttnn/unit_tests/operations/test_upsample.py |  1 +
 .../upsample_program_factory_multicore.cpp      | 16 ++++++++++------
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py
index d2795db29e8..7405bfda43e 100644
--- a/tests/ttnn/unit_tests/operations/test_upsample.py
+++ b/tests/ttnn/unit_tests/operations/test_upsample.py
@@ -119,6 +119,7 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
         [1, 64, 132, 19],
     ],
 )
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
 @pytest.mark.parametrize("scale_h", [2, 3])
 @pytest.mark.parametrize("scale_w", [2, 3])
 @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 3165c85e97c..d8c11117032 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -4,15 +4,19 @@
 #include 
 
-#include "buffers/buffer.hpp"
-#include "buffers/buffer_constants.hpp"
-#include "common/assert.hpp"
-#include "common/core_coord.hpp"
-#include "ttnn/tensor/host_buffer/functions.hpp"
+#include 
+
+#include "upsample_op.hpp"
+#include "ttnn/operations/math.hpp"
 #include "tt_metal/host_api.hpp"
+#include "tt_metal/common/constants.hpp"
+#include "tt_metal/detail/util.hpp"
 #include "tt_metal/common/math.hpp"
+#include "tt_metal/tt_stl/reflection.hpp"
+#include "ttnn/tensor/host_buffer/functions.hpp"
+
 using namespace tt::constants;
 using namespace tt::tt_metal;
@@ -229,7 +233,7 @@ operation::ProgramWithCallbacks upsample_multi_core(
                                                         ? ShardOrientation::COL_MAJOR
                                                         : shard_spec.orientation;
     ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false);
-    MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1, config_shard_spec};
+    MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec};
     auto config_tensor_device = config_tensor.to(device, memory_config);
     tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());

From 09b57cb1680e12b704c8ab838643f169292cc0ef Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Wed, 18 Dec 2024 17:50:29 +0000
Subject: [PATCH 08/11] Make pipeline pass.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Suggested-by: Pavle Josipović
Signed-off-by: Nilaykumar Patel

---
 .../experimental/functional_unet/tt/model_preprocessing.py  | 5 ++---
 .../kernels/dataflow/writer_upsample_multi_core_sharded.cpp | 3 +--
 .../upsample/device/upsample_program_factory_multicore.cpp  | 4 ++--
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/models/experimental/functional_unet/tt/model_preprocessing.py b/models/experimental/functional_unet/tt/model_preprocessing.py
index ff77e0083fa..1520ab454bf 100644
--- a/models/experimental/functional_unet/tt/model_preprocessing.py
+++ b/models/experimental/functional_unet/tt/model_preprocessing.py
@@ -46,11 +46,11 @@ def create_unet_model_parameters(
     for key in parameters.keys():
         parameters[key].module = getattr(model, key)
 
-    parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
+    parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
     parameters.c1["use_split_reader"] = True
     parameters.c1["use_activation_double_buffer"] = True
     parameters.c1["input_channels_alignment"] = 16
-    parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
+    parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
     parameters.c1_2["use_split_reader"] = True
     parameters.c1_2["use_activation_double_buffer"] = True
     parameters.c1_2["input_channels_alignment"] = 16
@@ -136,7 +136,6 @@ def create_unet_model_parameters(
     parameters.c8_3["use_split_reader"] = True
     parameters.c8_3["input_channels_alignment"] = 16
 
-    parameters.output_layer["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 16 * 32}
     parameters.output_layer["use_activation_double_buffer"] = True
     parameters.output_layer["use_split_reader"] = True
     parameters.output_layer["input_channels_alignment"] = 16
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
index 0fbee10ad5a..acaf9ba284a 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -66,7 +66,6 @@ void kernel_main() {
         }
     }
 
-    cb_push_back(out_cb_id, out_w);
-
     noc_async_read_barrier();
+    cb_push_back(out_cb_id, out_w);
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index d8c11117032..20dd9a8f9fc 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -240,7 +240,7 @@ operation::ProgramWithCallbacks upsample_multi_core(
     tt::DataFormat config_df = tt::DataFormat::RawUInt16;
     Buffer* config_buffer = config_tensor_device.buffer();
     auto config_buffer_page_size = config_buffer->page_size();
-    uint32_t config_cb_id = tt::CB::c_in2;
+    uint32_t config_cb_id = CBIndex::c_6;
     auto config_cb_config = CircularBufferConfig(config_buffer_page_size, {{config_cb_id, config_df}})
                                 .set_page_size(config_cb_id, config_buffer->page_size())
                                 .set_globally_allocated_address(*config_buffer);
@@ -307,7 +307,7 @@ operation::ProgramWithCallbacks upsample_multi_core(
         TT_THROW("Unsupported memory layout");
     }
 
-    auto override_runtime_args_callback = [writer_kernel, cb_src0, config_cb, out_cb](
+    auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb, config_cb](
                                               const void* operation,
                                               Program& program,
                                              const std::vector<Tensor>& input_tensors,

From ece940a3e7b5818a05e3b50f3fb998707444ba46 Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Fri, 20 Dec 2024 09:14:28 +0000
Subject: [PATCH 09/11] Address review comments.

Signed-off-by: Nilaykumar Patel

---
 .../unit_tests/operations/test_upsample.py |  5 ---
 .../writer_upsample_multi_core_sharded.cpp | 31 +++++--------------
 .../upsample_program_factory_multicore.cpp | 15 +++++----
 3 files changed, 16 insertions(+), 35 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py
index 7405bfda43e..4cbe29ec4d9 100644
--- a/tests/ttnn/unit_tests/operations/test_upsample.py
+++ b/tests/ttnn/unit_tests/operations/test_upsample.py
@@ -132,11 +132,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
     batch_size, num_channels, height, width = input_shape
     torch.manual_seed(0)
     input = torch.rand(input_shape, dtype=torch.bfloat16)
-    # for i in range(input_shape[0]):
-    #     for j in range(input_shape[1]):
-    #         for k in range(input_shape[2]):
-    #             for l in range(input_shape[3]):
-    #                 input[i, j, k, l] = (k * width + l + 1)
 
     ## golden reference using torch
     scale_factor = (scale_h, scale_w)
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
index acaf9ba284a..7b221b6c384 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -3,32 +3,13 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include 
-
 #include "dataflow_api.h"
-#define ENABLE_DEBUG_PRINT 0
-
-#if ENABLE_DEBUG_PRINT == 1
-#include "debug/dprint.h"
-
-inline void print_pages(uint32_t l1_addr, uint32_t pagelen, uint32_t npages, uint32_t start = 0) {
-    volatile tt_l1_ptr uint16_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(l1_addr) + start * pagelen;
-    for (uint32_t page = 0; page < npages; ++page) {
-        DPRINT << start + page << ": ";
-        for (uint32_t j = 0; j < pagelen; ++j, ++ptr) {
-            DPRINT << BF16(*ptr) << " ";
-        }
-        DPRINT << ENDL();
-    }
-}
-#endif
 
 void kernel_main() {
     uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
     uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
     uint32_t scale_h = get_arg_val<uint32_t>(2);
     uint32_t scale_w = get_arg_val<uint32_t>(3);
-    uint32_t in_w = get_arg_val<uint32_t>(4);
-    uint32_t out_w = get_arg_val<uint32_t>(5);
 
     constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
     constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
@@ -36,7 +17,7 @@ void kernel_main() {
     constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);
 
     uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
-    uint32_t writer_nsticks_per_core = in_nsticks_per_core / 2;
+    uint32_t out_nsticks_per_core = reader_nsticks_per_core * scale_h * scale_w;
     uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
     uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
     uint32_t l1_read_addr = get_read_ptr(in_cb_id);
@@ -46,10 +27,12 @@ void kernel_main() {
     volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);
 
     uint32_t reader_idx = 0;
-    if (!is_reader) {
-        reader_idx = 4 * (scale_h * image_row_begin);
+    if constexpr (!is_reader) {
+        /* For each input stick there are 4 entries in config cb {core_coords.x, core_coords.y, stick_offset(in
+         * input_cb), 0(padding)} so multiply input image_row_begin by (4 * scale_h) */
+        reader_idx = (4 * scale_h) * image_row_begin;
     }
-    cb_reserve_back(out_cb_id, out_w);
+    cb_reserve_back(out_cb_id, out_nsticks_per_core);
 
     for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
         for (uint32_t sh = 0; sh < scale_h; sh++) {
@@ -67,5 +50,5 @@ void kernel_main() {
     }
 
     noc_async_read_barrier();
-    cb_push_back(out_cb_id, out_w);
+    cb_push_back(out_cb_id, out_nsticks_per_core);
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 20dd9a8f9fc..d6187da0178 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -85,7 +85,6 @@ static Tensor create_config_tensor(
     CoreCoord core_coords;
     if (is_height_sharded) {
         for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
-            CoreCoord core_coords;
             if (is_col_major) {
                 core_coords = device->worker_core_from_logical_core(
                     CoreCoord(logical_core_to_stick_map[j], logical_core_to_stick_map[j + 1]));
@@ -109,7 +108,15 @@
             }
         }
     }
-    uint32_t elems_per_core = 4 * scale_factor_h * input_nsticks_per_core;
+    /* Each entry in config_vector contains 4 elements:
+     * {core_coords.x, core_coords.y, stick_offset(in input_cb), 0(padding)}
+     * - core_coords.x: X coordinate of the core
+     * - core_coords.y: Y coordinate of the core
+     * - stick_offset: Offset within the input circular buffer
+     * - padding: Always set to 0 for alignment purposes
+     */
+    const uint32_t config_buffer_entry_size = 4;
+    uint32_t elems_per_core = config_buffer_entry_size * scale_factor_h * input_nsticks_per_core;
     Shape config_shape({config_vector.size() / elems_per_core, elems_per_core});
     auto config_buffer = owned_buffer::create(std::move(config_vector));
     return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
@@ -150,7 +157,6 @@ operation::ProgramWithCallbacks upsample_multi_core(
         ncores);
 
     uint32_t in_nsticks_per_core = shard_spec.shape[0];
-    uint32_t out_nsticks_per_core = in_nsticks_per_core * scale_factor_h * scale_factor_w;
 
     if (input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
         TT_THROW("Unsupported sharding layout");
@@ -280,9 +286,6 @@ operation::ProgramWithCallbacks upsample_multi_core(
     writer_rt_args[1] = input_nsticks_per_core;
     writer_rt_args[2] = scale_factor_h;
     writer_rt_args[3] = scale_factor_w;
-    writer_rt_args[4] = input_nsticks_per_core;
-    writer_rt_args[5] = output_nsticks_per_core / 2;  // half of the outputs are processed by each core
-    writer_rt_args[6] = 0;                            // set for each core below
     uint32_t start_input_stick_id = 0;
 
     if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {

From 75973e19b01b4bf2a2028700223fa9ae74fe1c06 Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Tue, 24 Dec 2024 11:34:35 +0000
Subject: [PATCH 10/11] Reduce size of required CB for config buffer by half.

Signed-off-by: Nilaykumar Patel

---
 .../writer_upsample_multi_core_sharded.cpp | 12 +++++------
 .../upsample_program_factory_multicore.cpp | 21 ++++++++++---------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
index 7b221b6c384..44876387e1c 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -28,18 +28,18 @@ void kernel_main() {
 
     uint32_t reader_idx = 0;
     if constexpr (!is_reader) {
-        /* For each input stick there are 4 entries in config cb {core_coords.x, core_coords.y, stick_offset(in
-         * input_cb), 0(padding)} so multiply input image_row_begin by (4 * scale_h) */
-        reader_idx = (4 * scale_h) * image_row_begin;
+        /* For each input stick there are 2 entries in config cb {{core_coords.x, core_coords.y}, stick_offset(in
+         * input_cb)} so multiply input image_row_begin by (2 * scale_h) */
+        reader_idx = (2 * scale_h) * image_row_begin;
     }
     cb_reserve_back(out_cb_id, out_nsticks_per_core);
 
     for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
         for (uint32_t sh = 0; sh < scale_h; sh++) {
-            uint16_t corex = config_data[reader_idx++];
-            uint16_t corey = config_data[reader_idx++];
+            uint16_t cores = config_data[reader_idx++];
+            uint16_t corey = cores & 0xFF;
+            uint16_t corex = cores >> 8;
             uint16_t offset = config_data[reader_idx++];
-            reader_idx++;
             uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
             // replicate stick scale_w times.
             for (uint32_t sw = 0; sw < scale_w; sw++) {
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index d6187da0178..fe27097ad21 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -92,30 +92,31 @@ static Tensor create_config_tensor(
                 core_coords = device->worker_core_from_logical_core(
                     CoreCoord(logical_core_to_stick_map[j] % ncores_x, logical_core_to_stick_map[j] / ncores_x));
             }
-            config_vector.push_back(core_coords.x);
-            config_vector.push_back(core_coords.y);
+            // Combine the x and y coordinates of the core into a single 16-bit value.
+            // The x coordinate is shifted left by 8 bits and added to the y coordinate.
+            uint16_t cores = (core_coords.x << 8) + core_coords.y;
+            config_vector.push_back(cores);
             config_vector.push_back(logical_core_to_stick_map[j + 2]);
-            config_vector.push_back(0);
         }
     } else {
         for (size_t i = 0; i < ncores_x; i++) {
            for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
                core_coords = device->worker_core_from_logical_core(CoreCoord(i, logical_core_to_stick_map[j]));
-                config_vector.push_back(core_coords.x);
-                config_vector.push_back(core_coords.y);
+                // Combine the x and y coordinates of the core into a single 16-bit value.
+                // The x coordinate is shifted left by 8 bits and added to the y coordinate.
+                uint16_t cores = (core_coords.x << 8) + core_coords.y;
+                config_vector.push_back(cores);
                 config_vector.push_back(logical_core_to_stick_map[j + 2]);
-                config_vector.push_back(0);
             }
         }
     }
-    /* Each entry in config_vector contains 4 elements:
-     * {core_coords.x, core_coords.y, stick_offset(in input_cb), 0(padding)}
+    /* Each entry in config_vector contains 2 elements:
+     * {{core_coords.x, core_coords.y}, stick_offset(in input_cb)}
      * - core_coords.x: X coordinate of the core
      * - core_coords.y: Y coordinate of the core
      * - stick_offset: Offset within the input circular buffer
-     * - padding: Always set to 0 for alignment purposes
      */
-    const uint32_t config_buffer_entry_size = 4;
+    const uint32_t config_buffer_entry_size = 2;
     uint32_t elems_per_core = config_buffer_entry_size * scale_factor_h * input_nsticks_per_core;
     Shape config_shape({config_vector.size() / elems_per_core, elems_per_core});
     auto config_buffer = owned_buffer::create(std::move(config_vector));

From 3036426907ee370de6568a09dc4a5137ce2a18cb Mon Sep 17 00:00:00 2001
From: Nilaykumar Patel
Date: Thu, 9 Jan 2025 10:56:13 +0000
Subject: [PATCH 11/11] Update perf number for unet and rebase.

Signed-off-by: Nilaykumar Patel

---
 models/experimental/functional_unet/tests/test_unet_perf.py     | 2 +-
 .../pool/upsample/device/upsample_program_factory_multicore.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py
index b2d022ff08a..dcde736b46c 100644
--- a/models/experimental/functional_unet/tests/test_unet_perf.py
+++ b/models/experimental/functional_unet/tests/test_unet_perf.py
@@ -34,7 +34,7 @@
 @pytest.mark.models_device_performance_bare_metal
 @pytest.mark.parametrize(
     "batch, groups, expected_device_perf_fps",
-    ((1, 2, 1053.0),),
+    ((1, 2, 1040.0),),
 )
 def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
     command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 6565092cec2..a6f4de9c881 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -24,7 +24,7 @@ namespace ttnn::operations::upsample {
 using namespace tt;
 
 static Tensor create_config_tensor(
-    Device* device,
+    IDevice* device,
     ShardSpec shard_spec,
     const uint32_t batch_size,
     const uint32_t in_h,
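
Note on the packed config-buffer format (PATCH 10/11): each config entry now stores a worker core's (x, y) coordinates in a single uint16_t, x in the high byte and y in the low byte, followed by the stick offset, halving each entry from four uint16_t values to two. The sketch below is a minimal, host-only illustration of that encode/decode round trip; it is not part of the patches, the helper name pack_core_xy and the sample values are hypothetical, and it assumes (as the patch's comments imply) that both coordinates fit in 8 bits.

#include <cassert>
#include <cstdint>
#include <vector>

// Encode a worker core's (x, y) into one uint16_t: x in the high byte,
// y in the low byte, mirroring create_config_tensor after PATCH 10/11.
static uint16_t pack_core_xy(uint16_t x, uint16_t y) {
    assert(x < 256 && y < 256);  // both coordinates must fit in 8 bits
    return static_cast<uint16_t>((x << 8) + y);
}

int main() {
    // Hypothetical sample entry: two uint16_t values per (stick, scale_h) pair,
    // {packed core xy, stick offset within the input circular buffer}.
    std::vector<uint16_t> config;
    const uint16_t core_x = 3, core_y = 5, stick_offset = 42;
    config.push_back(pack_core_xy(core_x, core_y));
    config.push_back(stick_offset);

    // Decode the entry the same way the writer kernel does after PATCH 10/11.
    uint32_t reader_idx = 0;
    uint16_t cores = config[reader_idx++];
    uint16_t corey = cores & 0xFF;  // low byte holds y
    uint16_t corex = cores >> 8;    // high byte holds x
    uint16_t offset = config[reader_idx++];
    assert(corex == core_x && corey == core_y && offset == stick_offset);
    return 0;
}

The design trade-off: packing halves the L1 footprint of the config circular buffer (the motivation stated in PATCH 10's subject line) in exchange for an 8-bit range limit per coordinate, which is ample for the core grids these kernels target.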