diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py
index b2d022ff08a..dcde736b46c 100644
--- a/models/experimental/functional_unet/tests/test_unet_perf.py
+++ b/models/experimental/functional_unet/tests/test_unet_perf.py
@@ -34,7 +34,7 @@
 @pytest.mark.models_device_performance_bare_metal
 @pytest.mark.parametrize(
     "batch, groups, expected_device_perf_fps",
-    ((1, 2, 1053.0),),
+    ((1, 2, 1040.0),),
 )
 def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float):
     command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py
index e4a8846e3fc..4cbe29ec4d9 100644
--- a/tests/ttnn/unit_tests/operations/test_upsample.py
+++ b/tests/ttnn/unit_tests/operations/test_upsample.py
@@ -10,7 +10,7 @@
 import torch
 import torch.nn as nn
 import ttnn
-from models.utility_functions import skip_for_grayskull, skip_for_blackhole
+from models.utility_functions import skip_for_grayskull, skip_for_blackhole, is_grayskull
 from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout
 
@@ -109,12 +109,25 @@ def test_upsample_single_core(device, input_shapes, scale_h, scale_w):
         [1, 64, 132, 10],
         [1, 32, 8, 8],
         [2, 640, 32, 32],
+        # some random shapes
+        [1, 32, 5, 4],
+        [3, 32, 4, 4],
+        [5, 64, 5, 5],
+        [1, 128, 5, 8],
+        [1, 32, 5, 4],
+        [1, 64, 128, 17],
+        [1, 64, 132, 19],
     ],
 )
-@pytest.mark.parametrize("scale_h", [2])
-@pytest.mark.parametrize("scale_w", [2])
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+@pytest.mark.parametrize("scale_h", [2, 3])
+@pytest.mark.parametrize("scale_w", [2, 3])
 @pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.BLOCK])
-def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
+@pytest.mark.parametrize("shard_orientation", [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR])
+def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy, shard_orientation):
+    if is_grayskull() and (scale_h > 2 or scale_w > 2):
+        pytest.skip("Skipping test because it won't fit in L1!")
+
     ## input shape is N C H W
     batch_size, num_channels, height, width = input_shape
     torch.manual_seed(0)
@@ -136,15 +149,15 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
     max_grid_size = (device_grid.y, device_grid.x)
     if shard_strategy == ttnn.ShardStrategy.HEIGHT:
         ## nsticks per shard should be divisible by in_w
-        max_nshards = min(batch_size * height, max_grid_size[0] * max_grid_size[1])
+        max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
         nshards = max_nshards
         while nshards > 0:
-            if batch_size * height % nshards == 0:
+            if batch_size * height * width % nshards == 0:
                 break
             nshards -= 1
         ncores = nshards
     elif shard_strategy == ttnn.ShardStrategy.BLOCK:
-        max_nshards_h = min(batch_size * height, max_grid_size[0])  ## height along NHW
+        max_nshards_h = min(batch_size * height * width, max_grid_size[0])  ## height along NHW
         max_nshards_w = min(num_channels, max_grid_size[1])  ## width along C
         ## find nshards_h along NHW
         nshards_h = max_nshards_h
@@ -177,7 +190,6 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
 
     # )
     shard_grid = get_shard_grid_from_num_cores(device, ncores)
-    shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
 
     if shard_strategy == ttnn.ShardStrategy.BLOCK:
         tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
@@ -351,6 +363,7 @@ def test_bilinear_multi_core(
 
     ## compare the results
     torch_result = torch_result.permute(0, 2, 3, 1)
+
     passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
     allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
     logger.info(pcc_msg)
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
index 001da7bba6e..44876387e1c 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp
@@ -7,59 +7,48 @@
 void kernel_main() {
     uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
-    uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
+    uint32_t in_nsticks_per_core = get_arg_val<uint32_t>(1);
     uint32_t scale_h = get_arg_val<uint32_t>(2);
     uint32_t scale_w = get_arg_val<uint32_t>(3);
-    uint32_t in_w = get_arg_val<uint32_t>(4);
-    uint32_t out_w = get_arg_val<uint32_t>(5);
 
     constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
     constexpr uint32_t out_cb_id = get_compile_time_arg_val(1);
     constexpr uint32_t is_reader = get_compile_time_arg_val(2);
-
-    uint32_t in_image_row_nbytes = in_w * stick_nbytes;
-    uint32_t out_image_row_nbytes = out_w * stick_nbytes;
-    uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
-    uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
-    uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
-    uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
-    uint32_t l1_read_addr = get_read_ptr(in_cb_id) + image_row_begin * in_image_row_nbytes;
-    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * out_image_row_nbytes;
-
-    cb_reserve_back(out_cb_id, out_w);
-
-    // assuming shard begins with a new row. TODO: generalize?
-    for (uint32_t image_row = image_row_begin; image_row < image_row_end; ++image_row) {
-        uint32_t l1_write_addr_image_row_start = l1_write_addr;
-        for (uint32_t i = 0; i < in_w; ++i) {
+    constexpr uint32_t config_cb_id = get_compile_time_arg_val(3);
+
+    uint32_t reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
+    uint32_t out_nsticks_per_core = reader_nsticks_per_core * scale_h * scale_w;
+    uint32_t image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
+    uint32_t image_row_end = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
+    uint32_t l1_read_addr = get_read_ptr(in_cb_id);
+    uint32_t l1_write_addr = get_write_ptr(out_cb_id) + image_row_begin * scale_h * scale_w * stick_nbytes;
+
+    uint32_t config_l1_addr = get_read_ptr(config_cb_id);
+    volatile tt_l1_ptr uint16_t* config_data = reinterpret_cast<volatile tt_l1_ptr uint16_t*>(config_l1_addr);
+
+    uint32_t reader_idx = 0;
+    if constexpr (!is_reader) {
+        /* For each input stick there are 2 entries in config cb {{core_coords.x, core_coords.y}, stick_offset(in
+         * input_cb)} so multiply input image_row_begin with (2 * scale_h) */
+        reader_idx = (2 * scale_h) * image_row_begin;
+    }
+    cb_reserve_back(out_cb_id, out_nsticks_per_core);
+
+    for (uint32_t row_begin = image_row_begin; row_begin < image_row_end; ++row_begin) {
+        for (uint32_t sh = 0; sh < scale_h; sh++) {
+            uint16_t cores = config_data[reader_idx++];
+            uint16_t corey = cores & 0xFF;
+            uint16_t corex = cores >> 8;
+            uint16_t offset = config_data[reader_idx++];
+            uint64_t src_remote_addr = get_noc_addr(corex, corey, l1_read_addr + offset * stick_nbytes);
             // replicate stick scale_w times.
-            for (uint32_t sw = 0; sw < scale_w; ++sw) {
-                // replicate stick scale_w times.
-                if constexpr (is_reader) {
-                    uint64_t src_noc_addr = get_noc_addr(l1_read_addr);
-                    noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
-                } else {
-                    uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-                    noc_async_write(l1_read_addr, dst_noc_addr, stick_nbytes);
-                }
+            for (uint32_t sw = 0; sw < scale_w; sw++) {
+                noc_async_read(src_remote_addr, l1_write_addr, stick_nbytes);
                 l1_write_addr += stick_nbytes;
             }
-            l1_read_addr += stick_nbytes;
         }
-
-        // Duplicate the whole image row in one shot
-        if constexpr (is_reader) {
-            uint64_t src_noc_addr = get_noc_addr(l1_write_addr_image_row_start);
-            noc_async_read(src_noc_addr, l1_write_addr, out_image_row_nbytes);
-        } else {
-            uint64_t dst_noc_addr = get_noc_addr(l1_write_addr);
-            noc_async_write(l1_write_addr_image_row_start, dst_noc_addr, out_image_row_nbytes);
-        }
-        l1_write_addr += out_image_row_nbytes;
     }
-    cb_push_back(out_cb_id, out_w);
-
-    noc_async_write_barrier();
     noc_async_read_barrier();
+    cb_push_back(out_cb_id, out_nsticks_per_core);
 }
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
index 79eab51d7cf..a6f4de9c881 100644
--- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
@@ -2,6 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include
+
 #include
 
 #include "upsample_op.hpp"
@@ -13,6 +15,7 @@
 #include "tt_metal/common/math.hpp"
 #include "tt_metal/tt_stl/reflection.hpp"
+#include "ttnn/tensor/host_buffer/functions.hpp"
 
 using namespace tt::constants;
 using namespace tt::tt_metal;
@@ -20,6 +23,106 @@ using namespace tt::tt_metal;
 namespace ttnn::operations::upsample {
 using namespace tt;
 
+static Tensor create_config_tensor(
+    IDevice* device,
+    ShardSpec shard_spec,
+    const uint32_t batch_size,
+    const uint32_t in_h,
+    const uint32_t in_w,
+    const uint32_t scale_factor_h,
+    const uint32_t scale_factor_w,
+    const uint32_t ncores_x,
+    const bool is_height_sharded,
+    const bool is_col_major) {
+    uint16_t in_core = 0, curr_stick = 0;
+    const uint32_t input_nsticks_per_core = shard_spec.shape[0];
+
+    std::vector<std::vector<uint16_t>> core_range;
+    auto ranges = shard_spec.grid.ranges();
+    // in case of height sharding and shards arranged in column major order, get cores where shard are placed.
+    if (is_col_major && is_height_sharded) {
+        for (auto i = 0; i < ranges.size(); i++) {
+            auto range = ranges[i];
+            for (auto x = range.start_coord.x; x <= range.end_coord.x; x++) {
+                for (auto y = range.start_coord.y; y <= range.end_coord.y; y++) {
+                    core_range.push_back({x, y});
+                }
+            }
+        }
+    }
+
+    std::vector<uint16_t> logical_core_to_stick_map;
+    size_t logical_core_to_stick_map_entry_size = 3;
+    size_t row_size = logical_core_to_stick_map_entry_size * in_w;
+    // Create map of core and respective offsets in input
+    for (uint32_t b = 0; b < batch_size; ++b) {
+        for (uint32_t h = 0; h < in_h; ++h) {
+            for (uint32_t w = 0; w < in_w; ++w, ++curr_stick) {
+                if (curr_stick == input_nsticks_per_core) {
+                    curr_stick = 0, ++in_core;
+                }
+                if (is_height_sharded && is_col_major) {
+                    logical_core_to_stick_map.push_back(core_range[in_core][0]);
+                    logical_core_to_stick_map.push_back(core_range[in_core][1]);
+                } else {
+                    logical_core_to_stick_map.push_back(in_core);
+                    logical_core_to_stick_map.push_back(0);
+                }
+                logical_core_to_stick_map.push_back(curr_stick);
+            }
+            for (uint32_t j = 1; j < scale_factor_h; ++j) {
+                logical_core_to_stick_map.insert(
+                    logical_core_to_stick_map.end(),
+                    logical_core_to_stick_map.end() - row_size,
+                    logical_core_to_stick_map.end());
+            }
+        }
+    }
+
+    std::vector<uint16_t> config_vector;
+
+    // Based on core calculate physical location of cores
+    CoreCoord core_coords;
+    if (is_height_sharded) {
+        for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
+            if (is_col_major) {
+                core_coords = device->worker_core_from_logical_core(
+                    CoreCoord(logical_core_to_stick_map[j], logical_core_to_stick_map[j + 1]));
+            } else {
+                core_coords = device->worker_core_from_logical_core(
+                    CoreCoord(logical_core_to_stick_map[j] % ncores_x, logical_core_to_stick_map[j] / ncores_x));
+            }
+            // Combine the x and y coordinates of the core into a single 16-bit value.
+            // The x coordinate is shifted left by 8 bits and added to the y coordinate.
+            uint16_t cores = (core_coords.x << 8) + core_coords.y;
+            config_vector.push_back(cores);
+            config_vector.push_back(logical_core_to_stick_map[j + 2]);
+        }
+    } else {
+        for (size_t i = 0; i < ncores_x; i++) {
+            for (size_t j = 0; j < logical_core_to_stick_map.size(); j += logical_core_to_stick_map_entry_size) {
+                core_coords = device->worker_core_from_logical_core(CoreCoord(i, logical_core_to_stick_map[j]));
+                // Combine the x and y coordinates of the core into a single 16-bit value.
+                // The x coordinate is shifted left by 8 bits and added to the y coordinate.
+                uint16_t cores = (core_coords.x << 8) + core_coords.y;
+                config_vector.push_back(cores);
+                config_vector.push_back(logical_core_to_stick_map[j + 2]);
+            }
+        }
+    }
+    /* Each entry in config_vector contains 2 elements:
+     * {{core_coords.x, core_coords.y}, stick_offset(in input_cb)}
+     * - core_coords.x: X coordinate of the core
+     * - core_coords.y: Y coordinate of the core
+     * - stick_offset: Offset within the input circular buffer
+     */
+    const uint32_t config_buffer_entry_size = 2;
+    uint32_t elems_per_core = config_buffer_entry_size * scale_factor_h * input_nsticks_per_core;
+    Shape config_shape({config_vector.size() / elems_per_core, elems_per_core});
+    auto config_buffer = owned_buffer::create(std::move(config_vector));
+    return Tensor(OwnedStorage{config_buffer}, config_shape, DataType::UINT16, Layout::ROW_MAJOR);
+}
+
 operation::ProgramWithCallbacks upsample_multi_core(
     const Tensor& input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) {
     Program program = CreateProgram();
@@ -55,23 +158,17 @@ operation::ProgramWithCallbacks upsample_multi_core(
         ncores);
 
     uint32_t in_nsticks_per_core = shard_spec.shape[0];
-    uint32_t out_nsticks_per_core = in_nsticks_per_core * scale_factor_h * scale_factor_w;
+
+    if (input.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
+        TT_THROW("Unsupported sharding layout");
+    }
 
     // extra limitation to avoid post upsample step of resharding
-    if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
-        TT_FATAL(
-            in_nsticks_per_core % in_w == 0,
-            "Restriction: Input sticks per core {} should be divisible by input width {}. TODO to remove this "
-            "restriction",
-            in_nsticks_per_core,
-            in_w);
-    } else if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
+    if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
         ncores_x = all_cores.ranges().begin()->end_coord.x + 1;
         ncores_nhw = all_cores.ranges().begin()->end_coord.y + 1;
         input_stick_nbytes = input_stick_nbytes / ncores_x;
         output_stick_nbytes = output_stick_nbytes / ncores_x;
-    } else {
-        TT_THROW("Unsupported sharding layout");
     }
 
     uint32_t input_nsticks_per_core = div_up(input_nsticks, ncores_nhw);
@@ -83,12 +180,6 @@ operation::ProgramWithCallbacks upsample_multi_core(
         "Input sticks per shard {} should be same as input sticks per core {}",
         in_nsticks_per_core,
         input_nsticks_per_core);
-    TT_FATAL(
-        out_nsticks_per_core == output_nsticks_per_core,
-        "Output sticks per shard {} should be same as output sticks per core {}",
-        out_nsticks_per_core,
-        output_nsticks_per_core);
-    TT_FATAL(input_nsticks_per_core % in_w == 0, "Error");
 
     // CBs
@@ -126,12 +217,49 @@ operation::ProgramWithCallbacks upsample_multi_core(
         input_nsticks_per_core,
         output_nsticks_per_core);
 
+    // create config tensor
+    Tensor config_tensor;
+    if ((input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) ||
+        (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED)) {
+        config_tensor = create_config_tensor(
+            device,
+            shard_spec,
+            input.legacy_shape()[0],
+            input.legacy_shape()[1],
+            in_w,
+            scale_factor_h,
+            scale_factor_w,
+            ncores_x,
+            input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED,
+            shard_spec.orientation == ShardOrientation::COL_MAJOR);
+    } else {
+        TT_THROW("Unsupported sharding layout");
+    }
+    auto shard_shape = std::array<uint32_t, 2>({1, (uint32_t)config_tensor.get_shape()[-1]});
+    auto config_tensor_shard_orientation = input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED
+                                               ? ShardOrientation::COL_MAJOR
+                                               : shard_spec.orientation;
+    ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation, false);
+    MemoryConfig memory_config{input.memory_config().memory_layout, BufferType::L1_SMALL, config_shard_spec};
+    auto config_tensor_device = config_tensor.to(device, memory_config);
+    tt::tt_metal::detail::AddConfigBuffer(program, config_tensor_device.device_buffer());
+
+    tt::DataFormat config_df = tt::DataFormat::RawUInt16;
+    Buffer* config_buffer = config_tensor_device.buffer();
+    auto config_buffer_page_size = config_buffer->page_size();
+    uint32_t config_cb_id = CBIndex::c_6;
+    auto config_cb_config = CircularBufferConfig(config_buffer_page_size, {{config_cb_id, config_df}})
+                                .set_page_size(config_cb_id, config_buffer->page_size())
+                                .set_globally_allocated_address(*config_buffer);
+    CBHandle config_cb = CreateCircularBuffer(program, all_cores, config_cb_config);
+
     // Kernels
     std::vector<uint32_t> writer_compile_time_args = {
         in_cb_id,
         out_cb_id,
         false,
+        config_cb_id,
     };
     auto writer_kernel_fname = std::string(
         "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
@@ -142,6 +270,7 @@ operation::ProgramWithCallbacks upsample_multi_core(
         in_cb_id,
         out_cb_id,
         true,
+        config_cb_id,
     };
     auto reader_kernel_fname = std::string(
         "ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/writer_upsample_multi_core_sharded.cpp");
@@ -155,12 +284,9 @@ operation::ProgramWithCallbacks upsample_multi_core(
     uint32_t writer_nargs = 7;
     std::vector<uint32_t> writer_rt_args(writer_nargs);
     writer_rt_args[0] = input_stick_nbytes;
-    writer_rt_args[1] = input_nsticks_per_core / in_w;
+    writer_rt_args[1] = input_nsticks_per_core;
     writer_rt_args[2] = scale_factor_h;
     writer_rt_args[3] = scale_factor_w;
-    writer_rt_args[4] = in_w;
-    writer_rt_args[5] = out_w;
-    writer_rt_args[6] = 0;  // set for each core below
 
     uint32_t start_input_stick_id = 0;
     if (input.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
@@ -185,7 +311,7 @@ operation::ProgramWithCallbacks upsample_multi_core(
         TT_THROW("Unsupported memory layout");
     }
 
-    auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb](
+    auto override_runtime_args_callback = [writer_kernel, cb_src0, out_cb, config_cb](
                                               const void* operation,
                                               Program& program,
                                               const std::vector<Tensor>& input_tensors,
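
A few standalone sketches follow to illustrate how the pieces above fit together. None of this code is from the tree; all sizes, grids, and helper names in the sketches are assumptions chosen for illustration.

The new kernel and create_config_tensor share a small packed format: each config entry is a pair of uint16 values, the first packing the source core's NoC coordinates (x in the high byte, y in the low byte) and the second the stick offset inside that core's input shard. A minimal round-trip sketch of that packing (plain C++, no tt-metal headers):

#include <cassert>
#include <cstdint>

// Pack a core's (x, y) NoC coordinates into one uint16_t the way
// create_config_tensor does: x in the high byte, y in the low byte.
static uint16_t pack_core_xy(uint8_t x, uint8_t y) {
    return static_cast<uint16_t>((x << 8) + y);
}

// Unpack on the device side the way the writer kernel does:
//   corey = cores & 0xFF;  corex = cores >> 8;
static void unpack_core_xy(uint16_t packed, uint8_t& x, uint8_t& y) {
    y = static_cast<uint8_t>(packed & 0xFF);
    x = static_cast<uint8_t>(packed >> 8);
}

int main() {
    uint8_t x = 0, y = 0;
    unpack_core_xy(pack_core_xy(3, 7), x, y);
    assert(x == 3 && y == 7);
    return 0;
}

Keeping both coordinates in one uint16 keeps the table compact and lets the kernel fetch a whole entry with two consecutive reads from the RawUInt16 config CB.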
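
The factory comment describes what each config_vector entry holds; concretely, create_config_tensor walks the input sticks in N, H, W order, records which core/offset currently owns each stick, and then repeats each image row's block of entries so it appears scale_h times, one entry per output row that stick produces. A toy reconstruction of that bookkeeping under made-up sizes (the shard layout below is an assumption, not a value from the PR):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model of the {packed_core, stick_offset} table: one entry per stick of an
// input image row, with the whole row block repeated scale_h times.
int main() {
    const uint32_t batch = 1, in_h = 4, in_w = 4;   // 16 input sticks (assumed sizes)
    const uint32_t scale_h = 2;
    const uint32_t nsticks_per_core = 8;            // height sharded across 2 cores

    std::vector<uint16_t> config;
    uint16_t core = 0, offset = 0;
    for (uint32_t b = 0; b < batch; ++b) {
        for (uint32_t h = 0; h < in_h; ++h) {
            const size_t row_start = config.size();
            for (uint32_t w = 0; w < in_w; ++w, ++offset) {
                if (offset == nsticks_per_core) {   // stick spills onto the next core's shard
                    offset = 0;
                    ++core;
                }
                config.push_back(static_cast<uint16_t>(core << 8));  // packed core id (y kept 0 here)
                config.push_back(offset);                            // stick offset inside that core's shard
            }
            // Repeat the row block so it appears scale_h times in total.
            const std::vector<uint16_t> row(config.begin() + static_cast<std::ptrdiff_t>(row_start), config.end());
            for (uint32_t j = 1; j < scale_h; ++j) {
                config.insert(config.end(), row.begin(), row.end());
            }
        }
    }

    // Each core's shard of the table holds 2 * scale_h * nsticks_per_core uint16 values.
    const size_t elems_per_core = 2 * scale_h * nsticks_per_core;
    assert(config.size() == elems_per_core * 2);  // 2 cores in this toy example
    return 0;
}

That per-core element count is the elems_per_core the factory uses to shape and shard the config tensor, and it is also why the non-reader kernel instance starts at reader_idx = (2 * scale_h) * image_row_begin.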
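
Both data-movement kernels are built from the same writer_upsample_multi_core_sharded.cpp source and split one core's sticks purely through the compile-time is_reader flag (false for one instance, true for the other, as in the compile-time args above). A small sketch of that split and the property it relies on:

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Mirrors the kernel's split, evaluated by each instance with its own is_reader value:
//   reader_nsticks_per_core = (in_nsticks_per_core + is_reader) / 2;
//   image_row_begin = is_reader ? 0 : reader_nsticks_per_core;
//   image_row_end   = is_reader ? reader_nsticks_per_core : in_nsticks_per_core;
struct StickRange {
    uint32_t begin;
    uint32_t end;
};

static StickRange split(uint32_t in_nsticks_per_core, bool is_reader) {
    const uint32_t reader_nsticks = (in_nsticks_per_core + (is_reader ? 1u : 0u)) / 2;
    return is_reader ? StickRange{0, reader_nsticks} : StickRange{reader_nsticks, in_nsticks_per_core};
}

int main() {
    for (uint32_t n : {8u, 7u, 34u}) {
        const StickRange reader = split(n, true);    // instance compiled with is_reader = true
        const StickRange writer = split(n, false);   // instance compiled with is_reader = false
        assert(reader.begin == 0 && writer.end == n);  // together they cover every stick
        assert(writer.begin <= reader.end);            // for odd n the halves overlap by one stick
                                                       // rather than leaving a gap
    }
    return 0;
}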
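
On the test side, dropping the old in_w-divisibility restriction means HEIGHT sharding only has to find a core count that evenly divides the total stick count batch * height * width. A sketch of that search, using shapes from the new parametrization and an assumed 8x8 worker grid:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Pick the largest shard count <= available cores that evenly divides the total
// number of input sticks, so every shard gets the same number of sticks.
static uint32_t pick_nshards(uint32_t total_sticks, uint32_t max_cores) {
    uint32_t nshards = std::min(total_sticks, max_cores);
    while (nshards > 0 && total_sticks % nshards != 0) {
        --nshards;
    }
    return nshards;
}

int main() {
    // [1, 32, 5, 4] -> 1 * 5 * 4 = 20 sticks; on an assumed 8x8 grid this gives 20 shards of
    // 1 stick each, which is not a multiple of in_w = 4 -- the case the old restriction rejected.
    assert(pick_nshards(1 * 5 * 4, 64) == 20);
    // [2, 640, 32, 32] -> 2 * 32 * 32 = 2048 sticks -> 64 shards of 32 sticks each.
    assert(pick_nshards(2 * 32 * 32, 64) == 64);
    return 0;
}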