From 1673fd938d200a56d438ff95eefa053b4a17504f Mon Sep 17 00:00:00 2001 From: Shwetank Singh Date: Thu, 29 Aug 2024 06:11:49 +0000 Subject: [PATCH 1/2] #5725: Adding bilinear support in upsample --- .../unit_tests/operations/test_upsample.py | 141 +++++++++- ttnn/CMakeLists.txt | 1 + .../device/kernels/compute/bilinear.cpp | 64 +++++ .../reader_bilinear_multi_core_sharded.cpp | 108 ++++++++ ...ple_bilinear_program_factory_multicore.cpp | 262 ++++++++++++++++++ .../pool/upsample/device/upsample_op.cpp | 33 ++- .../pool/upsample/device/upsample_op.hpp | 19 +- .../upsample_program_factory_multicore.cpp | 13 +- .../upsample_program_factory_singlecore.cpp | 25 +- .../operations/pool/upsample/upsample.cpp | 10 +- .../operations/pool/upsample/upsample.hpp | 4 +- .../pool/upsample/upsample_pybind.cpp | 3 +- .../sliding_window/sliding_window.cpp | 11 +- .../sliding_window/sliding_window.hpp | 1 + 14 files changed, 639 insertions(+), 56 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp create mode 100644 ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp create mode 100644 ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp diff --git a/tests/ttnn/unit_tests/operations/test_upsample.py b/tests/ttnn/unit_tests/operations/test_upsample.py index ba18bf54dfc..86047a86581 100644 --- a/tests/ttnn/unit_tests/operations/test_upsample.py +++ b/tests/ttnn/unit_tests/operations/test_upsample.py @@ -4,13 +4,14 @@ import pytest import math +from loguru import logger from typing import Union, Tuple import torch import torch.nn as nn import ttnn - -from tests.ttnn.utils_for_testing import assert_with_pcc +from models.utility_functions import skip_for_grayskull, skip_for_blackhole +from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout TILE_WIDTH = 32 @@ -222,3 +223,139 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate assert allclose assert isclose assert isequal + + +@skip_for_grayskull() +@skip_for_blackhole() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, num_channels, height, width, scale_h, scale_w", + ( + (1, 256, 16, 16, 8, 8), # 256x256 + (1, 256, 32, 32, 4, 4), # 256x256 + (1, 256, 64, 64, 2, 2), # 256x256 + (1, 256, 128, 128, 1, 1), # 256x256 + ), +) +@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT]) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("math_approx_mode", [True, False]) +def test_bilinear_multi_core( + device, + use_program_cache, + batch_size, + num_channels, + height, + width, + scale_h, + scale_w, + shard_strategy, + math_fidelity, + math_approx_mode, +): + ## input shape is N C H W + input_shape = [batch_size, num_channels, height, width] + torch.manual_seed(0) + input = torch.rand(input_shape, dtype=torch.bfloat16) + + ## golden reference using torch + scale_factor = (scale_h, scale_w) + torch_upsample = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False) + torch_result = torch_upsample(input) + + ## permute to N H W C, which is what the upsample op expects + tt_input = input.permute(0, 2, 3, 1) + + num_bytes = 2 ## only BFLOAT16 is supported + + ## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape + ncores = None + device_grid = 
device.compute_with_storage_grid_size() + max_grid_size = (device_grid.y, device_grid.x) + if shard_strategy == ttnn.ShardStrategy.HEIGHT: + ## nsticks per shard should be divisible by in_w + max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1]) + nshards = max_nshards + while nshards > 0: + if batch_size * height * width % (nshards * TILE_WIDTH) == 0: + break + nshards -= 1 + ncores = nshards + elif shard_strategy == ttnn.ShardStrategy.BLOCK: + max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW + max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C + ## find nshards_h along NHW + nshards_h = max_nshards_h + while nshards_h > 0: + if batch_size * height % nshards_h == 0: + break + nshards_h -= 1 + ## find nshards_w along C + nshards_w = max_nshards_w + while nshards_w > 0: + ## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B + if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0: + break + nshards_w -= 1 + if nshards_w == 0 or nshards_h == 0: + raise ValueError("nshards_h or nshards_w is 0") + ncores = (nshards_h, nshards_w) + + shard_grid = get_shard_grid_from_num_cores(device, ncores) + shard_orientation = ttnn.ShardOrientation.ROW_MAJOR + + if shard_strategy == ttnn.ShardStrategy.BLOCK: + tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED + elif shard_strategy == ttnn.ShardStrategy.HEIGHT: + tensor_memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED + + ## input shard + if shard_strategy == ttnn.ShardStrategy.BLOCK: + shard_height = math.ceil(batch_size * height * width / ncores[0]) + shard_width = math.ceil(num_channels / ncores[1]) + elif shard_strategy == ttnn.ShardStrategy.HEIGHT: + shard_height = math.ceil(batch_size * height * width / ncores) + shard_width = num_channels + # breakpoint() + shard_shape = (shard_height, shard_width) + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False) + in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec) + + ## output shard + shard_height = shard_height * scale_h * scale_w + shard_shape = (shard_height, shard_width) + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False) + + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=math_fidelity, + math_approx_mode=math_approx_mode, + fp32_dest_acc_en=False, + ) + + out_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec) + + logger.debug(f"in_shard_mem_config: {in_sharded_mem_config}") + logger.debug(f"out_shard_mem_config: {out_sharded_mem_config}") + + ## ttnn uses NHWC, so need to set scale_factor_c = 1 + scale_factor = (scale_h, scale_w, 1) + input_tensor = ttnn.from_torch(tt_input, device=device) + input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config) + output_tensor = ttnn.upsample( + input_tensor, + scale_factor, + mode="bilinear", + memory_config=out_sharded_mem_config, + compute_kernel_config=compute_kernel_config, + ) + output_tensor = ttnn.to_memory_config(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.to_torch(output_tensor) + + ## compare the results + torch_result = torch_result.permute(0, 2, 3, 1) + passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999) + allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1) + 
logger.info(pcc_msg) + + assert allclose + assert passing diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 95723cbb324..289583a56ca 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -288,6 +288,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device//upsample_bilinear_program_factory_multicore.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp new file mode 100644 index 00000000000..c845f21f8b2 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/tilize.h" +#include "compute_kernel_api/reduce.h" +#include "compute_kernel_api/pack_untilize.h" + +template +inline void reduce_h_fused( + const uint32_t in_cb_id, + const uint32_t in_scalar_cb_id, + const uint32_t in_ntiles_hwc, + const uint32_t in_stick_index, + const uint32_t out_cb_id) { + + cb_reserve_back(out_cb_id, 1); + tile_regs_acquire(); + cb_wait_front(in_cb_id, 4); + unpack_tilizeA_B_block(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, 2 /* unpack 1 or 2 faces ) */, unpA_face_r_dim); + for (uint32_t c_i = 0; c_i < in_ntiles_c; ++c_i) { + reduce_tile_math(c_i, 2 /* reduce 1 or 2 faces */); + } + cb_pop_front(in_cb_id, 4); + + tile_regs_wait(); + tile_regs_commit(); + pack_untilize_dst(out_cb_id, 1, 0, 1, 2); /* pack 1 row (1x16 or 1x32) */ + tile_regs_release(); + + cb_push_back(out_cb_id, 1); +} + +namespace NAMESPACE{ +void MAIN{ + constexpr uint32_t out_cb_id = tt::CB::c_out0; + constexpr uint32_t in1_cb_id = tt::CB::c_in1; + constexpr uint32_t bias_cb_id = tt::CB::c_in2; + constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; + constexpr uint32_t in2_cb_id = tt::CB::c_intermed0; + + constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0); + constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1); + constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2); + constexpr uint32_t window_size_hw = get_compile_time_arg_val(3); + constexpr uint32_t out_h = get_compile_time_arg_val(4); + constexpr uint32_t out_w = get_compile_time_arg_val(5); + constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7); + + constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(8); + constexpr uint32_t num_output_tiles = out_ntiles_c; //* nblocks; + + tilizeA_B_reduce_init(in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, 2, 4); + pack_untilize_dst_init_short(out_cb_id, 1, 2); /* pack 1 row (1x16 or 1x32) */ + for(uint32_t i = 0; i < nsticks_per_core_by_nblocks; i++){ + cb_wait_front(in_scalar_cb_id, 1); + reduce_h_fused(in1_cb_id, + in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id); + cb_pop_front(in_scalar_cb_id, 1); + } +} // MAIN +} //NAMESPACE 
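
For reference, the compute kernel above tilizes the four neighbour sticks gathered by the reader kernel (next file) and reduces them against the per-pixel scalar weights; the weighting corresponds to standard bilinear interpolation with align_corners=False, matching the torch reference used in the test. A minimal NumPy sketch of the per-pixel math, single channel only, ignoring the halo/read_offset handling and sharding done on device (function name and shapes here are illustrative, not part of the patch):

    import numpy as np

    def bilinear_upsample_hw(src, scale_h, scale_w):
        # src: (H, W) single-channel array; returns (H*scale_h, W*scale_w).
        in_h, in_w = src.shape
        out = np.empty((in_h * scale_h, in_w * scale_w), dtype=np.float32)
        for oy in range(in_h * scale_h):
            # align_corners=False source coordinate: (dst + 0.5) / scale - 0.5, clamped at 0
            y = max((oy + 0.5) / scale_h - 0.5, 0.0)
            y1 = int(y)
            y2 = min(y1 + 1, in_h - 1)
            dy = y - y1
            for ox in range(in_w * scale_w):
                x = max((ox + 0.5) / scale_w - 0.5, 0.0)
                x1 = int(x)
                x2 = min(x1 + 1, in_w - 1)
                dx = x - x1
                # the same four weights the reader kernel packs with fill_four_val()
                out[oy, ox] = ((1 - dx) * (1 - dy) * src[y1, x1]
                               + dx * (1 - dy) * src[y1, x2]
                               + (1 - dx) * dy * src[y2, x1]
                               + dx * dy * src[y2, x2])
        return out
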
diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp new file mode 100644 index 00000000000..38b5e1f4513 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" + +#define ALWI inline __attribute__((always_inline)) + +// Fill given four values into the memory starting at the given address. +// WARNING: Use with caution as there's no memory protection. Make sure size is within limits +ALWI bool fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) { + volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(begin_addr); + + ptr[0] = (val | (val1 << 16)); + ptr[1] = (val2 | (val3 << 16)); + return true; +} + + +void kernel_main() { + + uint32_t stick_nbytes = get_arg_val(0); + uint32_t in_image_rows_per_core = get_arg_val(1); + uint32_t scale_h = get_arg_val(2); + uint32_t scale_w = get_arg_val(3); + uint32_t in_w = get_arg_val(4); + uint32_t out_w = get_arg_val(5); + uint32_t src1_addr = get_arg_val(6); + uint32_t read_offset = get_arg_val(8); + uint32_t is_last_row = get_arg_val(9); + uint32_t in_h = 1; + constexpr bool src1_is_dram = false; + + constexpr uint32_t in_cb_id = get_compile_time_arg_val(0); + constexpr uint32_t out_cb_id = tt::CB::c_in1; + constexpr uint32_t is_reader = get_compile_time_arg_val(2); + + uint32_t in_image_row_nbytes = in_w * stick_nbytes; + uint32_t out_image_row_nbytes = out_w * stick_nbytes; + uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2; + uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2; + uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core; + uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core; + uint32_t l1_read_addr = get_read_ptr(in_cb_id); //+ image_row_begin * in_image_row_nbytes; + constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; + + // assuming shard begins with a new row. TODO: generalize? + float scale_h_inv = 1.0f / scale_h; + float scale_w_inv = 1.0f / scale_w; + float x, y, x_index, y_index, dx, dy; + y_index = (float)(0.5f) * (float)scale_h_inv + 0.5f; + for (uint32_t image_row = 0 ; image_row < in_image_rows_per_core * scale_h; ++image_row){ + x_index = (float)(0.5f) * (float)scale_w_inv -0.5f; + for(uint32_t j=0; j < in_w * scale_w; j++){ + cb_reserve_back(out_cb_id, 4); + cb_reserve_back(in_scalar_cb_id, 1); + + x = x_index < 0 ? 0 : x_index; + y = y_index < read_offset ? 
read_offset : y_index; + dx = x - int(x); + dy = y - int(y); + + uint32_t x1 = int(x); + uint32_t y1 = int(y); + uint32_t x2 = min(x1 + 1, in_w-1); + uint32_t y2 = y1 + 1; //, in_image_rows_per_core - 1); + if(is_last_row){ + y2 = min(y2, in_image_rows_per_core); //if last row, y2 should be in_image_rows_per_core + } + + fill_four_val(get_write_ptr(in_scalar_cb_id), float_to_bfloat16((1-dx) * (1-dy)), + float_to_bfloat16(dx * (1 - dy)), float_to_bfloat16((1 - dx) * dy), float_to_bfloat16(dx * dy)); + + uint32_t l1_write_addr = get_write_ptr(out_cb_id); + uint32_t l1_read_addr_temp = l1_read_addr + x1 * stick_nbytes + y1 * in_w * stick_nbytes; + //1st tile + uint64_t src_noc_addr = get_noc_addr(l1_read_addr_temp); + noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); + l1_write_addr += stick_nbytes; + + //2nd tile + l1_read_addr_temp = l1_read_addr + y1 * in_w * stick_nbytes + x2 * stick_nbytes; + src_noc_addr = get_noc_addr(l1_read_addr_temp); + noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); + l1_write_addr += stick_nbytes; + + //3rd tile + l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x1 * stick_nbytes; + src_noc_addr = get_noc_addr(l1_read_addr_temp); + noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); + l1_write_addr += stick_nbytes; + + //4th tile + l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x2 * stick_nbytes; + src_noc_addr = get_noc_addr(l1_read_addr_temp); + noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes); + l1_write_addr += stick_nbytes; + + //push scaler and data into cb. + noc_async_read_barrier(); + cb_push_back(out_cb_id, 4); + cb_push_back(in_scalar_cb_id, 1); + x_index += scale_w_inv; + } + y_index += scale_h_inv; + } +} diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp new file mode 100644 index 00000000000..44c8fcfe501 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp @@ -0,0 +1,262 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "upsample_op.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" + +#include "tt_metal/host_api.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/common/math.hpp" +//#include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/operations/reduction/generic/device/reduce_op.hpp" // for reduce_op_utils + +#include "tt_metal/tt_stl/reflection.hpp" +#include "ttnn/deprecated/tt_numpy/functions.hpp" +#include "ttnn/operations/sliding_window/sliding_window.hpp" +#include "ttnn/operations/sliding_window/halo/halo.hpp" + +#include "ttnn/operations/core/core.hpp" + +using namespace tt::constants; + +namespace ttnn::operations::upsample { +using namespace tt; +using sliding_window::SlidingWindowConfig; + +Tensor HaloTensorCreation(const Tensor &input){ + int batch_size = input.get_legacy_shape()[0]; + int input_height = input.get_legacy_shape()[1]; + int input_width = input.get_legacy_shape()[2]; + int num_cores_nhw = input.shard_spec().value().num_cores(); + + ttnn::Tensor input_tensor = input; // tensor to return + SlidingWindowConfig sliding_window_config = SlidingWindowConfig( + batch_size, + {input_height, input_width}, + {2, 2}, //kernel size + {1, 1}, // stride + {1, 0}, //padding + {1, 1}, //dilation + num_cores_nhw, + input_tensor.memory_config().shard_spec.value().grid, + false, true); + + input_tensor = ttnn::operations::core::reshape( + input_tensor, + Shape(std::array{ + 1, + 1, + input.get_shape()[0] * input.get_shape()[1] * input.get_shape()[2], + input.get_shape()[3]})); + + auto halo_output = ttnn::halo( + DefaultQueueId, + input_tensor, + sliding_window_config, + 0, + false, + false, + 0, + input_tensor.memory_config(), + false); + + return halo_output; +} + +operation::ProgramWithCallbacks bilinear_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w, const DeviceComputeKernelConfig compute_kernel_config) { + Program program = CreateProgram(); + Device *device = input.device(); + + auto input_shape = input.get_legacy_shape(); + auto output_shape = output.get_legacy_shape(); + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + + // NOTE: input is assumed to have channels last format: {N, H, W, C}, {N, 1, H * W, C}, {1, 1, N * H * W, C} + // NOTE: Bfp8_b/TILE is not yet supported + uint32_t input_stick_nbytes = input.get_legacy_shape()[-1] * input.element_size(); + uint32_t output_stick_nbytes = output.get_legacy_shape()[-1] * output.element_size(); + TT_FATAL(input_stick_nbytes == output_stick_nbytes, "Input and output sticks should have same size"); + + uint32_t output_nsticks = output.volume() / output.get_legacy_shape()[-1]; + uint32_t input_nsticks = input.volume() / input.get_legacy_shape()[-1]; + + uint32_t in_w = input.get_legacy_shape()[2]; + uint32_t out_w =output.get_legacy_shape()[2]; + + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + + auto shard_spec = input.shard_spec().value(); + auto all_cores = shard_spec.grid; + uint32_t ncores = shard_spec.num_cores(); + uint32_t ncores_x = device->compute_with_storage_grid_size().x; + uint32_t ncores_nhw = ncores; + + auto out_shard_spec = output.shard_spec().value(); + 
TT_FATAL(out_shard_spec.num_cores() == ncores, "Output tensor should have same number of cores {} as input tensor {}", out_shard_spec.num_cores(), ncores); + + uint32_t in_nsticks_per_core = shard_spec.shape[0]; + uint32_t out_nsticks_per_core = in_nsticks_per_core * scale_factor_h * scale_factor_w; + + // extra limitation to avoid post upsample step of resharding + if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + TT_FATAL(in_nsticks_per_core % in_w == 0, "Restriction: Input sticks per core {} should be divisible by input width {}. TODO to remove this restriction", in_nsticks_per_core, in_w); + } else { + TT_FATAL(false, "Unsupported sharding layout"); + } + + uint32_t input_nsticks_per_core = div_up(input_nsticks, ncores_nhw); + uint32_t output_nsticks_per_core = div_up(output_nsticks, ncores_nhw); + + TT_FATAL(in_nsticks_per_core == input_nsticks_per_core, "Input sticks per shard {} should be same as input sticks per core {}", in_nsticks_per_core, input_nsticks_per_core); + TT_FATAL(out_nsticks_per_core == output_nsticks_per_core, "Output sticks per shard {} should be same as output sticks per core {}", out_nsticks_per_core, output_nsticks_per_core); + TT_FATAL(input_nsticks_per_core % in_w == 0, "Error"); + + //creating halo input tensor + auto halo_in = HaloTensorCreation(input); + auto halo_shard_shape = halo_in.shard_spec().value().shape; + + // CBs + uint32_t buffering_factor = 1; // data is already fully buffered in the CBs since its sharded + + // input data is in a sharded CB + uint32_t in_cb_id = CB::c_in0; + uint32_t aligned_input_stick_nbytes = round_up_to_mul32(input_stick_nbytes); + uint32_t in_cb_pagesize = aligned_input_stick_nbytes; + uint32_t in_cb_npages = halo_shard_shape[0] * buffering_factor; + CircularBufferConfig cb_src0_config = CircularBufferConfig( + in_cb_pagesize * in_cb_npages, + {{in_cb_id, input_cb_data_format}}) + .set_page_size(in_cb_id, in_cb_pagesize) + .set_globally_allocated_address(*halo_in.buffer()); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + //intermediate tensor CB + uint32_t in1_cb_id = CB::c_in1; + CircularBufferConfig cb_src1_config = CircularBufferConfig( + 4 * in_cb_pagesize, //since 4 pixels per page are needed for intermediate tensor. 
+ {{in1_cb_id, input_cb_data_format}}) + .set_page_size(in1_cb_id, in_cb_pagesize); + auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_src1_config); + + //scaler CB + uint32_t in_scalar_cb_id = CB::c_in4; + uint32_t in_scalar_cb_pagesize = tile_size(input_cb_data_format); + uint32_t in_scalar_cb_npages = 1; + CircularBufferConfig in_scalar_cb_config = + CircularBufferConfig(in_scalar_cb_npages * in_scalar_cb_pagesize, {{in_scalar_cb_id, input_cb_data_format}}) + .set_page_size(in_scalar_cb_id, in_scalar_cb_pagesize); + + + auto in_scalar_cb = tt_metal::CreateCircularBuffer(program, all_cores, in_scalar_cb_config); + + // output sharded CB with upsampled data + uint32_t out_cb_id = CB::c_out0; + uint32_t aligned_output_stick_nbytes = round_up_to_mul32(output_stick_nbytes); + uint32_t out_cb_pagesize = aligned_output_stick_nbytes; + uint32_t out_cb_npages = output_nsticks_per_core * buffering_factor; + CircularBufferConfig out_cb_config = CircularBufferConfig( + out_cb_pagesize * out_cb_npages, + {{out_cb_id, output_cb_data_format}}) + .set_page_size(out_cb_id, out_cb_pagesize) + .set_globally_allocated_address(*output.buffer()); + auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, out_cb_config); + + log_debug(LogOp, "input_cb: {}, npages: {}, pagesize: {}", in_cb_id, in_cb_npages, in_cb_pagesize); + log_debug(LogOp, "output_cb: {}, npages: {}, pagesize: {}", out_cb_id, out_cb_npages, out_cb_pagesize); + log_debug(LogOp, "input_stick_nbytes: {}, output_stick_nbytes: {}", input_stick_nbytes, output_stick_nbytes); + log_debug(LogOp, "ncores: {}, ncores_x: {}", ncores, ncores_x); + log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core); + + // Kernels + std::vector reader_compile_time_args = { + in_cb_id, + out_cb_id, + false, + }; + + string writer_kernel_fname, reader_kernel_fname, compute_kernel_fname; + + reader_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp"); + compute_kernel_fname = std::string("ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/compute/bilinear.cpp"); + + uint32_t in_ntiles_c = (uint32_t)std::ceil((float)input_shape[3] / constants::TILE_WIDTH); + std::vector compute_compile_time_args = { + 1, + in_ntiles_c, + 1 * in_ntiles_c, + 4, + output_shape[1], + output_shape[2], + (uint32_t)std::ceil((float)output_shape[2] / constants::TILE_HEIGHT), + (uint32_t)std::ceil((float)output_shape[3] / constants::TILE_WIDTH), + output_nsticks_per_core, // loop count with blocks + input_shape[3], + }; + + auto reader_kernel = + CreateKernel(program, reader_kernel_fname, all_cores, ReaderDataMovementConfig(reader_compile_time_args)); + TT_FATAL(fp32_dest_acc_en == false, "fp32_dest_acc_en as true not supported. 
#12787 issue raised"); + auto reduce_op = ReduceOpMath::SUM; + auto reduce_dim = ReduceOpDim::H; + auto compute_config = ComputeConfig{ + .math_fidelity = math_fidelity, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_compile_time_args, + .defines = reduce_op_utils::get_defines(reduce_op, reduce_dim)}; + + auto compute_kernel = + CreateKernel(program, compute_kernel_fname, all_cores, compute_config); + + // runtime args + uint32_t reader_nargs = 10; + vector reader_rt_args(reader_nargs); + reader_rt_args[0] = input_stick_nbytes; + reader_rt_args[1] = input_nsticks_per_core / in_w; + reader_rt_args[2] = scale_factor_h; + reader_rt_args[3] = scale_factor_w; + reader_rt_args[4] = in_w; + reader_rt_args[5] = out_w; + reader_rt_args[6] = 0; // set for each core below + + uint32_t start_input_stick_id = 0; + + if (input.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { + for (int32_t core = 0; core < ncores_nhw; ++core) { + CoreCoord core_coord(core % ncores_x, core / ncores_x); // logical + reader_rt_args[6] = start_input_stick_id; + reader_rt_args[8] = (core == 0) ? 1 : 0; + reader_rt_args[9] = (core == ncores_nhw-1) ? 1 : 0; + SetRuntimeArgs(program, reader_kernel, core_coord, reader_rt_args); + start_input_stick_id += input_nsticks_per_core; + } + } else { + TT_FATAL(false, "Unsupported memory layout"); + } + + auto override_runtime_args_callback = [reader_kernel, cb_src0, out_cb]( + const void* operation, + Program &program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors + ) { + auto halo_in = HaloTensorCreation(input_tensors.at(0)); + auto src_buffer = halo_in.buffer(); + auto dst_buffer = output_tensors.at(0).buffer(); + + UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer); + UpdateDynamicCircularBufferAddress(program, out_cb, *dst_buffer); + }; + + return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; +} + +} // namespace ttnn::operations::upsample diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp index c193b4dce59..7e450c8c47c 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp @@ -13,8 +13,8 @@ #include "tt_metal/common/work_split.hpp" #include "tt_metal/host_api.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn::operations::upsample { +using namespace tt; void UpSample::validate(const std::vector &input_tensors) const { const auto& input_tensor_a = input_tensors.at(0); @@ -26,7 +26,7 @@ void UpSample::validate(const std::vector &input_tensors) const { if (input_tensor_a.memory_config().is_sharded()) { TT_FATAL(input_tensor_a.memory_config().memory_layout == output_mem_config_.memory_layout, "Input tensor memory layout should be same as output tensor memory layout"); TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED, "Input tensor memory layout should be HEIGHT or BLOCK sharded"); - TT_FATAL(input_tensor_a.buffer()->buffer_type() == tt_metal::BufferType::L1, "Input buffer should be sharded in L1"); + TT_FATAL(input_tensor_a.buffer()->buffer_type() == tt::tt_metal::BufferType::L1, "Input buffer should be sharded in L1"); } } @@ -94,9 +94,19 @@ std::vector 
UpSample::create_output_tensors(const std::vector &i Tensor& output_tensor_0 = output_tensors.at(0); switch (get_parallelization_strategy(input_tensors)) { case UpSampleParallelizationStrategy::MULTI_CORE: - return upsample_multi_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_); + if (mode_ == "bilinear") { + return bilinear_multi_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_, this->compute_kernel_config_); + } else if(mode_ == "nearest") { + return upsample_multi_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_); + } else { + TT_THROW("Unsupported mode"); + } case UpSampleParallelizationStrategy::SINGLE_CORE: - return upsample_single_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_); + if(mode_ == "nearest") + return upsample_single_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_); + else{ + TT_THROW("Unsupported mode"); + } }; return upsample_single_core(input_tensor_0, output_tensor_0, scale_factor_h_, scale_factor_w_); } @@ -109,15 +119,4 @@ UpSampleParallelizationStrategy UpSample::get_parallelization_strategy(const std return UpSampleParallelizationStrategy::SINGLE_CORE; } -Tensor upsample(const Tensor &input, - int scale_factor_h, - int scale_factor_w, - const MemoryConfig& out_mem_config) { - return operation::run_without_autoformat(UpSample{scale_factor_h, - scale_factor_w, - out_mem_config}, - {input}).at(0); -} - -} // namespace tt_metal -} // namespace tt +} // namespace ttnn::operations::upsample diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp index 39801f82287..721f654716f 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.hpp @@ -6,9 +6,9 @@ #include "ttnn/tensor/tensor.hpp" #include "ttnn/run_operation.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -namespace tt { -namespace tt_metal { +namespace ttnn::operations::upsample { enum class UpSampleParallelizationStrategy { MULTI_CORE, SINGLE_CORE @@ -17,7 +17,9 @@ enum class UpSampleParallelizationStrategy { struct UpSample{ const int scale_factor_h_; const int scale_factor_w_; + const string mode_; const MemoryConfig output_mem_config_; + const DeviceComputeKernelConfig compute_kernel_config_; void validate(const std::vector &input_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; @@ -26,13 +28,8 @@ struct UpSample{ UpSampleParallelizationStrategy get_parallelization_strategy(const std::vector &input_tensors) const; }; -Tensor upsample(const Tensor &input, - int scale_factor_h, - int scale_factor_w, - const MemoryConfig& out_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w); +operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w); +operation::ProgramWithCallbacks bilinear_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w, const DeviceComputeKernelConfig compute_kernel_config_); -operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor& output, uint32_t scale_factor_h, uint32_t scale_factor_w); -operation::ProgramWithCallbacks 
upsample_multi_core(const Tensor &input, Tensor& output, uint32_t scale_factor_h, uint32_t scale_factor_w); - -} // namespace tt_metal -} // namespace tt +} // namespace ttnn::operations::upsample diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index c99561e87f8..3e25bdd7cdf 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -16,15 +16,15 @@ using namespace tt::constants; -namespace tt { -namespace tt_metal { +namespace ttnn::operations::upsample { +using namespace tt; -operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, uint32_t scale_factor_h, uint32_t scale_factor_w) { +operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program = CreateProgram(); Device *device = input.device(); - DataFormat input_cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); - DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); // NOTE: input is assumed to have channels last format: {N, H, W, C}, {N, 1, H * W, C}, {1, 1, N * H * W, C} // NOTE: Bfp8_b/TILE is not yet supported @@ -180,5 +180,4 @@ operation::ProgramWithCallbacks upsample_multi_core(const Tensor &input, Tensor& return {.program=std::move(program), .override_runtime_arguments_callback=override_runtime_args_callback}; } -} // namespace tt_metal -} // namespace tt +} // namespace ttnn::operations::upsample diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp index 675dc7524c0..426f1913439 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp @@ -16,11 +16,9 @@ using namespace tt::constants; -namespace tt { - -namespace tt_metal { - -operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor& output, uint32_t scale_factor_h, uint32_t scale_factor_w) { +namespace ttnn::operations::upsample { +using namespace tt; +operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor& output, const uint32_t scale_factor_h, const uint32_t scale_factor_w) { Program program{}; CoreRange core({0, 0}, {0, 0}); @@ -118,14 +116,16 @@ operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor ); auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers + const void* operation, + Program &program, + const std::vector& input_tensors, + const std::vector>&, + const std::vector& output_tensors ) { - auto src_buffer = input_buffers.at(0); + auto src_buffer = input_tensors.at(0).buffer(); - auto dst_buffer = output_buffers.at(0); + auto dst_buffer = output_tensors.at(0).buffer(); CoreCoord core = {0, 0}; @@ -140,8 +140,7 @@ 
operation::ProgramWithCallbacks upsample_single_core(const Tensor &input, Tensor } }; - return {std::move(program), override_runtime_args_callback}; + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; } -} // namespace tt_metal -} // namespace tt +} // namespace ttnn::operations::upsample diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp index 55fa8b66b38..62007997d57 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp @@ -11,8 +11,12 @@ namespace ttnn::operations::upsample { ttnn::Tensor ExecuteUpSample::invoke(const ttnn::Tensor& input_tensor, std::variant scale_factor, - const std::optional& output_mem_config) { - MemoryConfig mem_config = output_mem_config.value_or(ttnn::DRAM_MEMORY_CONFIG); + std::string mode, + std::optional output_mem_config, + std::optional compute_kernel_config) { + MemoryConfig mem_config = output_mem_config.value_or(input_tensor.memory_config()); + ttnn::DeviceComputeKernelConfig config = compute_kernel_config.value_or( + ttnn::init_device_compute_kernel_config(input_tensor.device()->arch(), std::nullopt, MathFidelity::HiFi4)); int scale_h = 1; int scale_w = 1; std::visit( @@ -61,7 +65,7 @@ ttnn::Tensor ExecuteUpSample::invoke(const ttnn::Tensor& input_tensor, //return ttnn::upsample(input_tensor, scale_h, scale_w, mem_config); auto output_tensor = operation::run( - UpSample{scale_h, scale_w, mem_config}, + UpSample{scale_h, scale_w, mode, mem_config, config}, {input_tensor}).front(); return output_tensor; } diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp index 6faa39583f3..376a0b20312 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.hpp @@ -16,7 +16,9 @@ struct ExecuteUpSample { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, std::variant scale_factor, - const std::optional& output_mem_config = std::nullopt); + std::string mode="nearest", + std::optional output_mem_config = std::nullopt, + std::optional compute_kernel_config = std::nullopt); }; } // upsample } // operations diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp index 110e6a648fa..2527177c199 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp @@ -33,7 +33,8 @@ void bind_upsample(py::module& module) { ttnn::upsample, doc, ttnn::pybind_arguments_t{ - py::arg("input_tensor"), py::arg("scale_factor"), py::kw_only(), py::arg("memory_config") = std::nullopt}); + py::arg("input_tensor"), py::arg("scale_factor"), py::kw_only(), py::arg("mode") = "nearest", + py::arg("memory_config") = std::nullopt, py::arg("compute_kernel_config") = std::nullopt}); } diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp index 3983e146d7e..127aeadf718 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp @@ -26,7 +26,13 @@ bool SlidingWindowConfig::has_parallel_config() const { Shape SlidingWindowConfig::get_output_shape() const { uint32_t output_h = (input_hw.first + 2 * pad_hw.first - window_hw.first - (dilation_hw.first - 1) * (window_hw.first - 1 
)) / stride_hw.first + 1; uint32_t output_w = (input_hw.second + 2 * pad_hw.second - window_hw.second - (dilation_hw.second - 1) * (window_hw.second - 1 )) / stride_hw.second + 1; - log_debug(tt::LogOp, "SlidingWindowConfig::output_size: {} {} {}", batch_size, output_h, output_w); + if(is_bilinear){ + //for bilinear input and output should be same.. and kernel size is 2x2 + // we need neighboring width in the output tensor + output_h = input_hw.first; + output_w = input_hw.second; + } + log_debug(tt::LogOp, "output_size: {} {} {}", batch_size, output_h, output_w); return Shape( std::vector{batch_size, output_h, output_w, 0}); } @@ -90,6 +96,9 @@ std::vector> generate_shard_boundaries(c uint32_t dilated_window_w = config.window_hw.second + (config.dilation_hw.second - 1) * (config.window_hw.second - 1 ); uint32_t halo_with_pad_len = (dilated_window_h - 1) * padded_input_w + dilated_window_w - 1; + if(config.is_bilinear){ + halo_with_pad_len = (config.window_hw.first) * padded_input_w; + } uint32_t output_index_start = 0; for (uint32_t core = 0; core < num_cores; ++ core) { uint32_t output_index_end = std::min(output_index_start + output_shard_h, max_index) - 1; diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp index 2be02a21fef..1a22e2cae66 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.hpp @@ -48,6 +48,7 @@ struct SlidingWindowConfig { CoreRangeSet core_range_set = std::set{CoreRange({0, 0}, {0, 0})}; // active cores bool snap_to_tile = false; + bool is_bilinear = false; std::string to_string() const; bool has_parallel_config() const; From 261c81e9d45ac71e681ac6a43cae055460f86756 Mon Sep 17 00:00:00 2001 From: Shwetank Singh Date: Wed, 18 Sep 2024 14:51:50 +0000 Subject: [PATCH 2/2] #5725: adding few validation for upsample. --- .../reader_bilinear_multi_core_sharded.cpp | 39 +++++++++++-------- ...ple_bilinear_program_factory_multicore.cpp | 17 +++++++- .../pool/upsample/device/upsample_op.cpp | 6 ++- .../operations/pool/upsample/upsample.cpp | 3 ++ 4 files changed, 46 insertions(+), 19 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp index 38b5e1f4513..89edffbc171 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/kernels/dataflow/reader_bilinear_multi_core_sharded.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 @@ -9,12 +9,18 @@ // Fill given four values into the memory starting at the given address. // WARNING: Use with caution as there's no memory protection. 
Make sure size is within limits -ALWI bool fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) { +ALWI void fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) { volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(begin_addr); ptr[0] = (val | (val1 << 16)); ptr[1] = (val2 | (val3 << 16)); - return true; +} + +ALWI float uint32_to_float(uint32_t f) +{ + float ret; + std::memcpy(&ret, &f, sizeof(float)); + return ret; } @@ -34,24 +40,23 @@ void kernel_main() { constexpr uint32_t in_cb_id = get_compile_time_arg_val(0); constexpr uint32_t out_cb_id = tt::CB::c_in1; - constexpr uint32_t is_reader = get_compile_time_arg_val(2); - - uint32_t in_image_row_nbytes = in_w * stick_nbytes; - uint32_t out_image_row_nbytes = out_w * stick_nbytes; - uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2; - uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2; - uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core; - uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core; - uint32_t l1_read_addr = get_read_ptr(in_cb_id); //+ image_row_begin * in_image_row_nbytes; + //constexpr uint32_t is_reader = get_compile_time_arg_val(2); + constexpr uint32_t scale_h_inv_comp = get_compile_time_arg_val(3); + constexpr uint32_t scale_w_inv_comp = get_compile_time_arg_val(4); + constexpr uint32_t y_index_comp = get_compile_time_arg_val(5); + constexpr uint32_t x_index_compute_comp = get_compile_time_arg_val(6); + + uint32_t l1_read_addr = get_read_ptr(in_cb_id); constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4; // assuming shard begins with a new row. TODO: generalize? - float scale_h_inv = 1.0f / scale_h; - float scale_w_inv = 1.0f / scale_w; + float scale_h_inv = uint32_to_float(scale_h_inv_comp); + float scale_w_inv = uint32_to_float(scale_w_inv_comp); float x, y, x_index, y_index, dx, dy; - y_index = (float)(0.5f) * (float)scale_h_inv + 0.5f; + y_index = uint32_to_float(y_index_comp); + float x_index_compute = uint32_to_float(x_index_compute_comp); for (uint32_t image_row = 0 ; image_row < in_image_rows_per_core * scale_h; ++image_row){ - x_index = (float)(0.5f) * (float)scale_w_inv -0.5f; + x_index = x_index_compute; for(uint32_t j=0; j < in_w * scale_w; j++){ cb_reserve_back(out_cb_id, 4); cb_reserve_back(in_scalar_cb_id, 1); @@ -64,7 +69,7 @@ void kernel_main() { uint32_t x1 = int(x); uint32_t y1 = int(y); uint32_t x2 = min(x1 + 1, in_w-1); - uint32_t y2 = y1 + 1; //, in_image_rows_per_core - 1); + uint32_t y2 = y1 + 1; if(is_last_row){ y2 = min(y2, in_image_rows_per_core); //if last row, y2 should be in_image_rows_per_core } diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp index 44c8fcfe501..ea7d7672b40 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
// // SPDX-License-Identifier: Apache-2.0 @@ -175,10 +175,25 @@ operation::ProgramWithCallbacks bilinear_multi_core(const Tensor &input, Tensor& log_debug(LogOp, "input_nsticks_per_core: {}, output_nsticks_per_core: {}", input_nsticks_per_core, output_nsticks_per_core); // Kernels + //computation needed for the bilinear kernel. Passing them as an argument. + float scale_h_inv = 1.0f / (float)scale_factor_h; + float scale_w_inv = 1.0f / (float)scale_factor_w; + float y_index = (float)(0.5f) * (float)scale_h_inv + 0.5f; + float x_index_compute = (float)(0.5f) * (float)scale_w_inv - 0.5f; + + uint32_t scale_h_inv_u32 = *reinterpret_cast(&scale_h_inv); + uint32_t scale_w_inv_u32 = *reinterpret_cast(&scale_w_inv); + uint32_t y_index_u32 = *reinterpret_cast(&y_index); + uint32_t x_index_compute_u32 = *reinterpret_cast(&x_index_compute); + std::vector reader_compile_time_args = { in_cb_id, out_cb_id, false, + scale_h_inv_u32, + scale_w_inv_u32, + y_index_u32, + x_index_compute_u32, }; string writer_kernel_fname, reader_kernel_fname, compute_kernel_fname; diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp index 7e450c8c47c..b4444f8f497 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp @@ -25,7 +25,11 @@ void UpSample::validate(const std::vector &input_tensors) const { TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16, "Input tensor data type should be BFLOAT16"); if (input_tensor_a.memory_config().is_sharded()) { TT_FATAL(input_tensor_a.memory_config().memory_layout == output_mem_config_.memory_layout, "Input tensor memory layout should be same as output tensor memory layout"); - TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED, "Input tensor memory layout should be HEIGHT or BLOCK sharded"); + if(mode_ == "nearest") + TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED, "Input tensor memory layout should be HEIGHT or BLOCK sharded"); + else if(mode_ == "bilinear") + TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Input tensor memory layout should be HEIGHT sharded"); + TT_FATAL(mode_ == "bilinear" || mode_ == "nearest", "Upsample only supports bilinear or nearest mode"); TT_FATAL(input_tensor_a.buffer()->buffer_type() == tt::tt_metal::BufferType::L1, "Input buffer should be sharded in L1"); } } diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp index 62007997d57..4fad1223ddb 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/upsample.cpp @@ -17,6 +17,9 @@ ttnn::Tensor ExecuteUpSample::invoke(const ttnn::Tensor& input_tensor, MemoryConfig mem_config = output_mem_config.value_or(input_tensor.memory_config()); ttnn::DeviceComputeKernelConfig config = compute_kernel_config.value_or( ttnn::init_device_compute_kernel_config(input_tensor.device()->arch(), std::nullopt, MathFidelity::HiFi4)); + if(mode.empty()) { + mode = "nearest"; + } int scale_h = 1; int scale_w = 1; std::visit(