diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index 3339eea5189..e698e93bfd9 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -277,6 +277,10 @@ Primary Operations .. autofunction:: tt_lib.operations.primary.moreh_norm_backward +.. autofunction:: tt_lib.operations.primary.moreh_nll_loss_unreduced + +.. autofunction:: tt_lib.operations.primary.moreh_nll_loss_unreduced_backward + Enums ===== diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss_unreduced.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss_unreduced.py new file mode 100644 index 00000000000..a86781746c2 --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss_unreduced.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import tt_lib as ttl +import pytest +from models.utility_functions import comp_allclose_and_pcc, is_wormhole_b0 +from loguru import logger + +from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( + get_compute_kernel_options, + compute_kernel_options, + compute_kernel_ids, + to_cpu, + to_npu, +) + + +def get_torch_tensors(shape): + C = shape[1] + target_shape = shape[:1] + shape[2:] + + cpu_dtype = torch.float32 + cpu_index_dtype = torch.long + + torch_input = torch.rand(shape, dtype=cpu_dtype).requires_grad_() + torch_target = torch.randint(0, C, target_shape, dtype=cpu_index_dtype) + torch_weight = torch.rand(C, dtype=cpu_dtype) + torch_output = torch.empty(target_shape, dtype=cpu_dtype) + + return torch_input, torch_target, torch_weight, torch_output + + +def get_tt_tensors(torch_input, torch_target, torch_weight, torch_output, device): + npu_index_dtype = ttl.tensor.DataType.INT32 + + tt_input = to_npu(torch_input, device) + tt_target = to_npu(torch_target, device, npu_dtype=npu_index_dtype) + tt_weight = to_npu(torch_weight, device) + tt_output = to_npu(torch_output, device) + + return tt_input, tt_target, tt_weight, tt_output + + +def get_tt_backward_tensors(torch_target, torch_weight, torch_output_grad, torch_input_grad, device): + npu_index_dtype = ttl.tensor.DataType.INT32 + + tt_target = to_npu(torch_target, device, npu_dtype=npu_index_dtype) + tt_weight = to_npu(torch_weight, device) + tt_output_grad = to_npu(torch_output_grad, device) + tt_input_grad = to_npu(torch_input_grad, device) + + return tt_target, tt_weight, tt_output_grad, tt_input_grad + + +def run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device, compute_kernel_options=None): + compute_kernel_config = get_compute_kernel_options(compute_kernel_options) + + (torch_input, torch_target, torch_weight, torch_output) = get_torch_tensors(shape) + + if none_weight: + torch_weight = None + + nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction="none") + torch_loss = nll_loss(torch_input, torch_target) + + (tt_input, tt_target, tt_weight, tt_output) = get_tt_tensors( + torch_input, torch_target, torch_weight, torch_output, device + ) + + tt_loss = ttl.operations.primary.moreh_nll_loss_unreduced( + tt_input, + tt_target, + tt_weight, + tt_output, + ignore_index, + compute_kernel_config=compute_kernel_config, + ) + + tt_loss_to_cpu = to_cpu(tt_loss, torch_target.shape) + + rtol = atol = 0.05 + passing, out = comp_allclose_and_pcc(torch_loss, tt_loss_to_cpu, pcc=0.999, rtol=rtol, 
atol=atol) + logger.debug(f"Out passing (param)={passing}") + logger.debug(f"Output pcc={out}") + + assert passing + + +def run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, device, compute_kernel_options=None): + compute_kernel_config = get_compute_kernel_options(compute_kernel_options) + + # run torch + (torch_input, torch_target, torch_weight, _) = get_torch_tensors(shape) + if none_weight: + torch_weight = None + + nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction="none") + torch_loss = nll_loss(torch_input, torch_target) + + output_grad = torch.randn_like(torch_loss) + torch_loss.backward(output_grad) + + # run tt + (tt_target, tt_weight, tt_output_grad, tt_input_grad) = get_tt_backward_tensors( + torch_target, torch_weight, output_grad, torch_input.grad, device + ) + + tt_input_grad = ttl.operations.primary.moreh_nll_loss_unreduced_backward( + tt_target, + tt_weight, + tt_output_grad, + tt_input_grad, + ignore_index, + compute_kernel_config=compute_kernel_config, + ) + tt_input_grad_to_cpu = to_cpu(tt_input_grad, torch_input.grad.shape) + + rtol = atol = 0.05 + passing, out = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_to_cpu, pcc=0.999, rtol=rtol, atol=atol) + + logger.debug(f"Out passing (param)={passing}") + logger.debug(f"Output pcc={out}") + + assert passing + + +@pytest.mark.parametrize( + "shape", + [ + (5, 10), + (500, 100), + (4, 3, 2, 4, 50, 70), + ], +) +@pytest.mark.parametrize("ignore_index", [1]) +@pytest.mark.parametrize("none_weight", [True, False]) +@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) +def test_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, compute_kernel_options, device, use_program_cache): + torch.manual_seed(0) + + run_moreh_nll_loss_unreduced( + shape, ignore_index, none_weight, device, compute_kernel_options=compute_kernel_options + ) + + +@pytest.mark.parametrize( + "shape", + [ + (32, 32), + (400, 300), + (20, 300, 320), + (5, 2, 5, 40, 70), + ], +) +@pytest.mark.parametrize("ignore_index", [1]) +@pytest.mark.parametrize("none_weight", [True, False]) +@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) +def test_moreh_nll_loss_unreduced_backward( + shape, ignore_index, none_weight, compute_kernel_options, device, use_program_cache +): + torch.manual_seed(0) + + run_moreh_nll_loss_unreduced_backward( + shape, ignore_index, none_weight, device, compute_kernel_options=compute_kernel_options + ) + + +@pytest.mark.parametrize( + "shape", + [ + (5, 10), + (5, 10, 10), + (5, 10, 10, 20), + ], +) +@pytest.mark.parametrize("none_weight", [True, False]) +def test_moreh_nll_loss_unreduced_callback(shape, none_weight, device, use_program_cache): + torch.manual_seed(0) + + ignore_index = 1 + + for _ in range(2): + run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device) + + +@pytest.mark.parametrize( + "shape", + [ + (2, 3), + (2, 3, 4), + (2, 3, 5, 4), + ], +) +@pytest.mark.parametrize("none_weight", [True, False]) +def test_moreh_nll_loss_unreduced_backward_test_callback(shape, none_weight, device, use_program_cache): + torch.manual_seed(0) + + ignore_index = 0 + + for _ in range(2): + run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, device) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py index 1e5ded1af4e..18379923556 100644 --- 
a/tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py @@ -4,6 +4,7 @@ import tt_lib as ttl from models.utility_functions import is_wormhole_b0 +import copy compute_kernel_options = [ False, # for grayskull @@ -33,3 +34,46 @@ def get_compute_kernel_options(compute_kernel_options): math_approx_mode=True, ) return compute_kernel_config + + +def to_cpu(npu_tensor, shape, *, cpu_layout=ttl.tensor.Layout.ROW_MAJOR): + if npu_tensor is None: + return None + + shape = list(shape) + + unpad_shape = copy.copy(shape) + + if shape == []: + unpad_shape = [1, 1] + + if len(shape) == 1: + unpad_shape = [1] + shape + + cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(unpad_shape).to_torch().reshape(shape) + + return cpu_tensor + + +def to_npu( + cpu_tensor, + device, + *, + npu_layout=ttl.tensor.Layout.TILE, + npu_dtype=ttl.tensor.DataType.BFLOAT16, + shape=None, +): + if cpu_tensor is None: + return None + + if shape is not None: + cpu_tensor = cpu_tensor.view(shape) + + if len(cpu_tensor.shape) == 1: + cpu_tensor = cpu_tensor.reshape([1, len(cpu_tensor)]) + + if len(cpu_tensor.shape) == 0: + cpu_tensor = cpu_tensor.reshape([1, 1]) + + npu_tensor = ttl.tensor.Tensor(cpu_tensor, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + return npu_tensor diff --git a/tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp b/tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp index 9976e6d6664..d0a5b4e26ba 100644 --- a/tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp +++ b/tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp @@ -21,6 +21,15 @@ static inline float bfloat16_to_float(uint16_t bfloat_val) { return f; } +static inline uint16_t float_to_bfloat16(float val) { + union { + float f; + uint32_t u; + } ret; + ret.f = val; + return uint16_t(ret.u >> 16); +} + #if defined(FP32_DEST_ACC_EN) using FP32_DEST_ACC_FTYPE = float; FORCE_INLINE FP32_DEST_ACC_FTYPE fp32_dest_acc_cast(uint16_t val) { return bfloat16_to_float(val); } @@ -838,3 +847,69 @@ void get_noc_offset(uint32_t h, uint32_t w, uint32_t element_size, uint32_t &noc noc_offset = noc_offset_alilgn_32; } + +template +volatile tt_l1_ptr T* get_read_ptr(uint32_t cb_id) { + + auto l1_write_addr = get_read_ptr(cb_id); + auto l1_ptr = reinterpret_cast(l1_write_addr); + return l1_ptr; +} + +template +volatile tt_l1_ptr T* get_write_ptr(uint32_t cb_id) { + + auto l1_write_addr = get_write_ptr(cb_id); + auto l1_ptr = reinterpret_cast(l1_write_addr); + return l1_ptr; +} + +// It reads values from one tile. +template +void read_tile(uint32_t cb_id, T addrgen, uint32_t noc_id, uint32_t size = 0, uint32_t offset = 0, bool do_reserve = true, bool do_push_back = true) { + + constexpr uint32_t onetile = 1; + + if (do_reserve) cb_reserve_back(cb_id, onetile); + + // If the size is 0, it reads one tile. + if (size == 0){ + size = get_tile_size(cb_id); + } + + auto l1_write_addr = get_write_ptr(cb_id); + auto noc_addr = get_noc_addr(noc_id, addrgen, offset); + noc_async_read(noc_addr, l1_write_addr, size); + noc_async_read_barrier(); + + if (do_push_back) cb_push_back(cb_id, onetile); +} + + +// It reads values from a tilized tensor with shape (1, W). 
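+// A 32x32 tile is stored as four 16x16 faces; for a (1, W) row vector only the
+// top row of face 0 (byte offset 0) and of face 1 (byte offset 256 * element_size)
+// carries data. Each tile is therefore fetched as two FACE_WIDTH-wide reads, and
+// the W values end up densely packed at the front of the CB so that callers can
+// index them linearly (e.g. weight_l1_ptr[c]).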
+template +void read_line(uint32_t cb_id, T addrgen, uint32_t num_tiles, bool do_reserve = true, bool do_push_back = true) { + + if (do_reserve) cb_reserve_back(cb_id, num_tiles); + + auto tile_bytes = get_tile_size(cb_id); + auto element_size = tile_bytes / 1024; + auto noc_read_size = FACE_WIDTH * element_size; + + uint32_t l1_write_addr = get_write_ptr(cb_id); + + for (uint32_t i = 0; i < num_tiles * 2; ++i) { + uint32_t noc_id = i / 2; + uint32_t noc_offset = 0; + if (noc_id * 2 != i) { + noc_offset += 256 * element_size; + } + auto src_noc_addr = get_noc_addr(noc_id, addrgen, noc_offset); + noc_async_read(src_noc_addr, l1_write_addr, noc_read_size); + noc_async_read_barrier(); + + l1_write_addr += noc_read_size; + } + + if (do_push_back) cb_push_back(cb_id, num_tiles); +} diff --git a/tt_eager/tt_dnn/op_library/CMakeLists.txt b/tt_eager/tt_dnn/op_library/CMakeLists.txt index 0f19ee7a9c0..5966b57f29c 100644 --- a/tt_eager/tt_dnn/op_library/CMakeLists.txt +++ b/tt_eager/tt_dnn/op_library/CMakeLists.txt @@ -86,6 +86,9 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/moreh_softmax_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_w_small/softmax_w_small.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_h_small/softmax_h_small.cpp diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.cpp new file mode 100644 index 00000000000..417f546f466 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.cpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_op.hpp" +#include "tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.hpp" + +#include "tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; +using namespace std; +using namespace tt::tt_metal; + +namespace tt { +namespace operations { +namespace primary { + + +Tensor moreh_nll_loss_unreduced( + const Tensor& input_tensor, + const Tensor& target_tensor, + const std::optional weight_tensor, + const std::optional output_tensor, + const int32_t ignore_index, + const MemoryConfig& memory_config, + std::optional compute_kernel_config) { + const Tensor& result = moreh_nll_loss_step2( + input_tensor, + target_tensor, + weight_tensor, + std::nullopt, + ignore_index, + "sum", + memory_config, + compute_kernel_config); + + return result; +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.hpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.hpp new file mode 100644 index 00000000000..3650c6c436c --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.hpp @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "tt_dnn/op_library/compute_kernel_config.hpp" +#include "tt_dnn/op_library/operation.hpp" +#include "tt_eager/tensor/tensor.hpp" + +namespace tt { +namespace operations { +namespace primary { + +using namespace tt_metal; + +Tensor moreh_nll_loss_unreduced( + const Tensor &input_tensor, + const Tensor &target_tensor, + const std::optional weight_tensor, + const std::optional output_tensor, + const int32_t ignore_index, + const MemoryConfig &memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional compute_kernel_config = std::nullopt); + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_2d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_2d.cpp new file mode 100644 index 00000000000..020396b19b1 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_2d.cpp @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" +#include "dprint.h" + +void kernel_main() { + uint32_t i = 0; + auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); + auto weight_addr = get_arg_val(i++); + auto ignore_index = static_cast(get_arg_val(i++)); + auto num_tiles_per_core = get_arg_val(i++); + auto start_id = get_arg_val(i++); + auto Nt = get_arg_val(i++); + auto C = get_arg_val(i++); + auto Ct = get_arg_val(i++); + + constexpr uint32_t cb_target = tt::CB::c_in0; + constexpr uint32_t cb_output_grad = tt::CB::c_in1; + constexpr uint32_t cb_weight = tt::CB::c_in2; + + constexpr uint32_t cb_input_grad = tt::CB::c_out0; + + // ublocks size defined in tiles + const uint32_t target_tile_bytes = get_tile_size(cb_target); + + const uint32_t weight_tile_bytes = get_tile_size(cb_weight); + const DataFormat weight_data_format = get_dataformat(cb_weight); + + const uint32_t output_grad_tile_bytes = get_tile_size(cb_output_grad); + const DataFormat output_grad_data_format = get_dataformat(cb_output_grad); + + constexpr bool target_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool output_grad_is_dram = get_compile_time_arg_val(1) == 1; + constexpr bool weight_is_dram = get_compile_time_arg_val(2) == 1; + + const InterleavedAddrGen addrg_target = { + .bank_base_address = target_addr, .page_size = target_tile_bytes}; + constexpr uint32_t onetile = 1; + +#if defined(WEIGHT) + const InterleavedAddrGen addrg_weight = { + .bank_base_address = weight_addr, + .page_size = weight_tile_bytes, + }; + + // weight: (1, C) + read_line(cb_weight, addrg_weight, Ct); + + cb_wait_front(cb_weight, Ct); + auto weight_l1_ptr = get_read_ptr(cb_weight); +#endif + + const InterleavedAddrGen addrg_output_grad = { + .bank_base_address = output_grad_addr, + .page_size = output_grad_tile_bytes, + }; + + read_line(cb_output_grad, addrg_output_grad, Nt); + + cb_wait_front(cb_output_grad, Nt); + + auto zero = float_to_bfloat16(0.0f); + + uint32_t end_id = start_id + num_tiles_per_core; + for (uint32_t i = start_id; i < end_id; ++i) { + uint32_t nt = i / Ct; + uint32_t ct = i % Ct; + + // target: (1, N) + auto target_noc_id = nt; + read_tile(cb_target, addrg_target, target_noc_id); + + cb_reserve_back(cb_input_grad, onetile); + cb_wait_front(cb_target, onetile); + + auto input_grad_l1_ptr = get_write_ptr(cb_input_grad); + auto target_l1_ptr = get_read_ptr(cb_target); + auto output_grad_l1_ptr = get_read_ptr(cb_output_grad); + + for (uint32_t h = 0; h < TILE_HEIGHT; h++) { + for (uint32_t w = 0; w < TILE_WIDTH; w++) { + uint32_t n = nt * TILE_HEIGHT + h; + uint32_t c = ct * TILE_WIDTH + w; + + uint32_t target_tilized_idx = get_tilized_idx(0, h); // target(0, n) + int32_t target_val = target_l1_ptr[target_tilized_idx]; + + uint32_t input_grad_idx = get_tilized_idx(h, w); // input_grad(n, c) + + uint16_t input_grad_val; + + if (target_val != ignore_index && target_val == static_cast(c)) { + float output_grad_val = bfloat16_to_float(output_grad_l1_ptr[n]); +#if defined(WEIGHT) + float weight_val = bfloat16_to_float(weight_l1_ptr[target_val]); + + input_grad_val = float_to_bfloat16(-output_grad_val * weight_val); +#else + input_grad_val = float_to_bfloat16(-output_grad_val); +#endif + } else { + input_grad_val = zero; + } + input_grad_l1_ptr[input_grad_idx] = input_grad_val; + } + } + + cb_push_back(cb_input_grad, onetile); + + cb_pop_front(cb_target, onetile); + } +} diff --git 
a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_3d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_3d.cpp new file mode 100644 index 00000000000..251cd19a083 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_3d.cpp @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" + +void kernel_main() { + uint32_t i = 0; + auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); + auto weight_addr = get_arg_val(i++); + auto ignore_index = static_cast(get_arg_val(i++)); + auto num_tiles_per_core = get_arg_val(i++); + auto start_id = get_arg_val(i++); + auto C = get_arg_val(i++); + auto Ct = get_arg_val(i++); + auto Wt = get_arg_val(i++); + + constexpr uint32_t cb_target = tt::CB::c_in0; + constexpr uint32_t cb_output_grad = tt::CB::c_in1; + constexpr uint32_t cb_weight = tt::CB::c_in2; + + constexpr uint32_t cb_input_grad = tt::CB::c_out0; + + // ublocks size defined in tiles + const uint32_t target_tile_bytes = get_tile_size(cb_target); + + const uint32_t weight_tile_bytes = get_tile_size(cb_weight); + const DataFormat weight_data_format = get_dataformat(cb_weight); + + const uint32_t output_grad_tile_bytes = get_tile_size(cb_output_grad); + const DataFormat output_grad_data_format = get_dataformat(cb_output_grad); + + constexpr bool target_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool output_grad_is_dram = get_compile_time_arg_val(1) == 1; + constexpr bool weight_is_dram = get_compile_time_arg_val(2) == 1; + + const InterleavedAddrGen addrg_target = { + .bank_base_address = target_addr, .page_size = target_tile_bytes}; + constexpr uint32_t onetile = 1; + +#if defined(WEIGHT) + const InterleavedAddrGen addrg_weight = { + .bank_base_address = weight_addr, + .page_size = weight_tile_bytes, + }; + + // weight: (1, C) + read_line(cb_weight, addrg_weight, Ct); + + cb_wait_front(cb_weight, Ct); + auto weight_l1_ptr = get_read_ptr(cb_weight); +#endif + + const InterleavedAddrGenFast addrg_output_grad = { + .bank_base_address = output_grad_addr, + .page_size = output_grad_tile_bytes, + .data_format = output_grad_data_format}; + + auto zero = float_to_bfloat16(0.0f); + + uint32_t end_id = start_id + num_tiles_per_core; + for (uint32_t i = start_id; i < end_id; ++i) { + uint32_t wt = i % Wt; + uint32_t nct = i / Wt; + uint32_t n = nct / Ct; + uint32_t nt = n / TILE_HEIGHT; + uint32_t ct = nct % Ct; + + // target: (N, W) + auto target_noc_id = nt * Wt + wt; + read_tile(cb_target, addrg_target, target_noc_id); + + // output_grad: (N, W) + auto output_grad_noc_id = nt * Wt + wt; + read_tile(cb_output_grad, addrg_output_grad, output_grad_noc_id); + + cb_reserve_back(cb_input_grad, onetile); + cb_wait_front(cb_target, onetile); + cb_wait_front(cb_output_grad, onetile); + + auto input_grad_l1_ptr = get_write_ptr(cb_input_grad); + auto target_l1_ptr = get_read_ptr(cb_target); + auto output_grad_l1_ptr = get_read_ptr(cb_output_grad); + + for (uint32_t h = 0; h < TILE_HEIGHT; h++) { + for (uint32_t w = 0; w < TILE_WIDTH; w++) { + uint32_t nw_tilized_idx = get_tilized_idx(n % TILE_HEIGHT, w); // target(n, w) + int32_t target_val = target_l1_ptr[nw_tilized_idx]; + + uint32_t c = ct * TILE_HEIGHT + h; + uint32_t input_grad_idx = get_tilized_idx(h, w); // 
input_grad(c, w) + + uint16_t input_grad_val; + + if (target_val != ignore_index && target_val == static_cast(c)) { + float output_grad_val = bfloat16_to_float(output_grad_l1_ptr[nw_tilized_idx]); + +#if defined(WEIGHT) + float weight_val = bfloat16_to_float(weight_l1_ptr[target_val]); + + input_grad_val = float_to_bfloat16(-output_grad_val * weight_val); +#else + input_grad_val = float_to_bfloat16(-output_grad_val); +#endif + } else { + input_grad_val = zero; + } + input_grad_l1_ptr[input_grad_idx] = input_grad_val; + } + } + + cb_push_back(cb_input_grad, onetile); + + cb_pop_front(cb_target, onetile); + + cb_pop_front(cb_output_grad, onetile); + } +} diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_4d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_4d.cpp new file mode 100644 index 00000000000..d0e00792132 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/reader_moreh_nll_loss_unreduced_backward_4d.cpp @@ -0,0 +1,115 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" + +void kernel_main() { + uint32_t i = 0; + auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); + auto weight_addr = get_arg_val(i++); + auto ignore_index = static_cast(get_arg_val(i++)); + auto num_tiles_per_core = get_arg_val(i++); + auto start_id = get_arg_val(i++); + auto num_inner_tile = get_arg_val(i++); + auto C = get_arg_val(i++); + auto Ct = get_arg_val(i++); + + constexpr uint32_t cb_target = tt::CB::c_in0; + constexpr uint32_t cb_output_grad = tt::CB::c_in1; + constexpr uint32_t cb_weight = tt::CB::c_in2; + + constexpr uint32_t cb_input_grad = tt::CB::c_out0; + + // ublocks size defined in tiles + const uint32_t target_tile_bytes = get_tile_size(cb_target); + + const uint32_t weight_tile_bytes = get_tile_size(cb_weight); + const DataFormat weight_data_format = get_dataformat(cb_weight); + + const uint32_t output_grad_tile_bytes = get_tile_size(cb_output_grad); + const DataFormat output_grad_data_format = get_dataformat(cb_output_grad); + + constexpr bool target_is_dram = get_compile_time_arg_val(0) == 1; + constexpr bool output_grad_is_dram = get_compile_time_arg_val(1) == 1; + constexpr bool weight_is_dram = get_compile_time_arg_val(2) == 1; + + const InterleavedAddrGen addrg_target = { + .bank_base_address = target_addr, .page_size = target_tile_bytes}; + const InterleavedAddrGenFast addrg_output_grad = { + .bank_base_address = output_grad_addr, + .page_size = output_grad_tile_bytes, + .data_format = output_grad_data_format}; + constexpr uint32_t onetile = 1; + +#if defined(WEIGHT) + const InterleavedAddrGen addrg_weight = { + .bank_base_address = weight_addr, + .page_size = weight_tile_bytes, + }; + + // weight: (1, C) + read_line(cb_weight, addrg_weight, Ct); + + cb_wait_front(cb_weight, Ct); + auto weight_l1_ptr = get_read_ptr(cb_weight); +#endif + + auto zero = float_to_bfloat16(0.0f); + + uint32_t end_id = start_id + num_tiles_per_core; + for (uint32_t i = start_id; i < end_id; ++i) { + uint32_t inner = i % num_inner_tile; + uint32_t nc = i / num_inner_tile; + uint32_t n = nc / C; + uint32_t c = nc % C; + + // target: (N, H, W) + auto target_noc_id = n * num_inner_tile + inner; + read_tile(cb_target, addrg_target, target_noc_id); + + // output_grad: (N, H, W) + auto output_grad_noc_id = n * 
num_inner_tile + inner; + read_tile(cb_output_grad, addrg_output_grad, output_grad_noc_id); + + cb_reserve_back(cb_input_grad, onetile); + cb_wait_front(cb_target, onetile); + cb_wait_front(cb_output_grad, onetile); + + auto input_grad_l1_ptr = get_write_ptr(cb_input_grad); + auto target_l1_ptr = get_read_ptr(cb_target); + auto output_grad_l1_ptr = get_read_ptr(cb_output_grad); + + for (uint32_t h = 0; h < TILE_HEIGHT; h++) { + for (uint32_t w = 0; w < TILE_WIDTH; w++) { + uint32_t idx = h * TILE_WIDTH + w; // target and input_grad idx + + int32_t target_val = target_l1_ptr[idx]; + + uint16_t input_grad_val; + + if (target_val != ignore_index && target_val == static_cast(c)) { + float output_grad_val = bfloat16_to_float(output_grad_l1_ptr[idx]); + +#if defined(WEIGHT) + float weight_val = bfloat16_to_float(weight_l1_ptr[target_val]); + + input_grad_val = float_to_bfloat16(-output_grad_val * weight_val); +#else + input_grad_val = float_to_bfloat16(-output_grad_val); +#endif + } else { + input_grad_val = zero; + } + + input_grad_l1_ptr[idx] = input_grad_val; + } + } + + cb_push_back(cb_input_grad, onetile); + + cb_pop_front(cb_target, onetile); + cb_pop_front(cb_output_grad, onetile); + } +} diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/writer_moreh_nll_loss_unreduced_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/writer_moreh_nll_loss_unreduced_backward.cpp new file mode 100644 index 00000000000..0669f261dcd --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/writer_moreh_nll_loss_unreduced_backward.cpp @@ -0,0 +1,35 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t i = 0; + auto input_grad_addr = get_arg_val(i++); + auto num_tiles_per_core = get_arg_val(i++); + auto start_id = get_arg_val(i++); + + constexpr uint32_t cb_input_grad = tt::CB::c_out0; + + const uint32_t input_grad_tile_bytes = get_tile_size(cb_input_grad); + const auto input_grad_data_format = get_dataformat(cb_input_grad); + + constexpr bool input_grad_is_dram = get_compile_time_arg_val(0) == 1; + + const InterleavedAddrGenFast input_grad_addrg = { + .bank_base_address = input_grad_addr, + .page_size = input_grad_tile_bytes, + .data_format = input_grad_data_format}; + + constexpr uint32_t onetile = 1; + + uint32_t end_id = start_id + num_tiles_per_core; + for (uint32_t i = start_id; i < end_id; ++ i) { + cb_wait_front(cb_input_grad, onetile); + uint32_t input_grad_l1_write_addr = get_read_ptr(cb_input_grad); + noc_async_write_tile(i, input_grad_addrg, input_grad_l1_write_addr); + noc_async_write_barrier(); + cb_pop_front(cb_input_grad, onetile); + } +} diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward.cpp new file mode 100644 index 00000000000..3b1b2cf2fd7 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward.cpp @@ -0,0 +1,473 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; +using namespace tt::tt_metal; + +namespace tt { +namespace operations { +namespace primary { + +namespace { + +operation::ProgramWithCallbacks moreh_nll_loss_unreduced_backward_impl_2d( + const Tensor &target, + const std::optional weight, + const Tensor &output_grad, + const Tensor &input_grad, + const int32_t ignore_index, + const CoreRange core_range, + const DeviceComputeKernelConfig compute_kernel_config) { + // split work + + // input_grad: (N, C) + auto input_grad_shape = input_grad.get_legacy_shape(); + auto N = input_grad_shape[0]; + auto channel_size = input_grad_shape[1]; + + auto W = input_grad_shape[-1]; + auto Wt = W / TILE_WIDTH; + + const bool weight_has_value = weight.has_value(); + + uint32_t core_w = core_range.end.x - core_range.start.x + 1; + uint32_t core_h = core_range.end.y - core_range.start.y + 1; + + uint32_t units_to_divide = input_grad.volume() / TILE_HEIGHT / TILE_WIDTH; + + auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = + split_work_to_cores(core_range, units_to_divide); + + auto arch = input_grad.device()->arch(); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + get_compute_kernel_config_args(arch, compute_kernel_config); + + Program program = Program(); + + // create circular buffers + tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); + + auto Ct = div_up(channel_size, TILE_WIDTH); + auto Nt = div_up(N, TILE_WIDTH); + CreateCircularBuffer( + program, + all_cores, + data_format, + { + {CB::c_in0, 1, tt::DataFormat::Int32}, // target + {CB::c_in1, Nt}, // output_grad + {CB::c_in2, static_cast(weight_has_value ? Ct : 0)}, // weight + {CB::c_out0, 1}, // input_grad + }); + + // create read/wrtie kernel + const std::vector reader_compile_time_args{ + static_cast(is_dram(target)), + static_cast(is_dram(output_grad)), + static_cast(is_dram(weight))}; + + const std::vector writer_compile_time_args{static_cast(is_dram(input_grad))}; + + std::map reader_defines; + std::map writer_defines; + + if (weight_has_value) { + reader_defines["WEIGHT"] = 1; + } + + if (fp32_dest_acc_en) { + reader_defines["FP32_DEST_ACC_EN"] = 1; + } + + auto reader_kernel_id = CreateReadKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "reader_moreh_nll_loss_unreduced_backward_2d.cpp", + all_cores, + reader_compile_time_args, + reader_defines); + auto writer_kernel_id = CreateWriteKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "writer_moreh_nll_loss_unreduced_backward.cpp", + all_cores, + writer_compile_time_args, + writer_defines); + + const auto target_addr = target.buffer()->address(); + const auto weight_addr = weight_has_value ? 
weight.value().buffer()->address() : 0; + const auto output_grad_addr = output_grad.buffer()->address(); + const auto input_grad_addr = input_grad.buffer()->address(); + + // Set Runtime Args + auto core_x_offset = core_range.start.x; + auto core_y_offset = core_range.start.y; + for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { + CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; + uint32_t units_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges"); + } + + std::vector reader_args = { + target_addr, + output_grad_addr, + weight_addr, + static_cast(ignore_index), + units_per_core, + tile_offset, + Nt, + channel_size, + Ct, + }; + + std::vector writer_args = {input_grad_addr, units_per_core, tile_offset}; + + SetRuntimeArgs(program, reader_kernel_id, core, reader_args); + SetRuntimeArgs(program, writer_kernel_id, core, writer_args); + + tile_offset += units_per_core; + } + + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; +} + + +operation::ProgramWithCallbacks moreh_nll_loss_unreduced_backward_impl_3d( + const Tensor &target, + const std::optional weight, + const Tensor &output_grad, + const Tensor &input_grad, + const int32_t ignore_index, + const CoreRange core_range, + const DeviceComputeKernelConfig compute_kernel_config) { + // split work + + // input_grad: (N, C, W) + auto input_grad_shape = input_grad.get_legacy_shape(); + auto N = input_grad_shape[0]; + auto channel_size = input_grad_shape[1]; + + auto W = input_grad_shape[-1]; + auto Ct = channel_size / TILE_HEIGHT; + auto Wt = W / TILE_WIDTH; + + auto target_shape = target.get_legacy_shape(); + auto num_inner_tile = target_shape[-1] / TILE_WIDTH; + + const bool weight_has_value = weight.has_value(); + + uint32_t core_w = core_range.end.x - core_range.start.x + 1; + uint32_t core_h = core_range.end.y - core_range.start.y + 1; + + uint32_t units_to_divide = input_grad.volume() / TILE_HEIGHT / TILE_WIDTH; + + auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = + split_work_to_cores(core_range, units_to_divide); + + auto arch = input_grad.device()->arch(); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + get_compute_kernel_config_args(arch, compute_kernel_config); + + Program program = Program(); + + // create circular buffers + tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); + + CreateCircularBuffer( + program, + all_cores, + data_format, + { + {CB::c_in0, 1, tt::DataFormat::Int32}, // target + {CB::c_in1, 1}, // output_grad + {CB::c_in2, static_cast(weight_has_value ? 
Ct : 0)}, // weight + {CB::c_out0, 1}, // input_grad + }); + + // create read/wrtie kernel + const std::vector reader_compile_time_args{ + static_cast(is_dram(target)), + static_cast(is_dram(output_grad)), + static_cast(is_dram(weight))}; + + const std::vector writer_compile_time_args{static_cast(is_dram(input_grad))}; + + std::map reader_defines; + std::map writer_defines; + + if (weight_has_value) { + reader_defines["WEIGHT"] = 1; + } + + if (fp32_dest_acc_en) { + reader_defines["FP32_DEST_ACC_EN"] = 1; + } + + auto reader_kernel_id = CreateReadKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "reader_moreh_nll_loss_unreduced_backward_3d.cpp", + all_cores, + reader_compile_time_args, + reader_defines); + auto writer_kernel_id = CreateWriteKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "writer_moreh_nll_loss_unreduced_backward.cpp", + all_cores, + writer_compile_time_args, + writer_defines); + + const auto target_addr = target.buffer()->address(); + const auto output_grad_addr = output_grad.buffer()->address(); + const auto weight_addr = weight_has_value ? weight.value().buffer()->address() : 0; + const auto input_grad_addr = input_grad.buffer()->address(); + + // Set Runtime Args + auto core_x_offset = core_range.start.x; + auto core_y_offset = core_range.start.y; + for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { + CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; + uint32_t units_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges"); + } + + std::vector reader_args = { + target_addr, + output_grad_addr, + weight_addr, + static_cast(ignore_index), + units_per_core, + tile_offset, + channel_size, + Ct, + Wt, + }; + + std::vector writer_args = {input_grad_addr, units_per_core, tile_offset}; + + SetRuntimeArgs(program, reader_kernel_id, core, reader_args); + SetRuntimeArgs(program, writer_kernel_id, core, writer_args); + + tile_offset += units_per_core; + } + + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; +} + +operation::ProgramWithCallbacks moreh_nll_loss_unreduced_backward_impl_4d( + const Tensor &target, + const std::optional weight, + const Tensor &output_grad, + const Tensor &input_grad, + const int32_t ignore_index, + const CoreRange core_range, + const DeviceComputeKernelConfig compute_kernel_config) { + // split work + auto input_grad_shape = input_grad.get_legacy_shape(); + auto N = input_grad_shape[0]; + auto channel_size = input_grad_shape[1]; + + auto Ct = div_up(channel_size, TILE_WIDTH); + + auto H = input_grad_shape[-2]; + auto W = input_grad_shape[-1]; + auto Ht = H / TILE_HEIGHT; + auto Wt = W / TILE_WIDTH; + auto num_inner_tile = target.volume() / N / TILE_HEIGHT / TILE_WIDTH; + + const bool weight_has_value = weight.has_value(); + + uint32_t core_w = core_range.end.x - core_range.start.x + 1; + uint32_t core_h = core_range.end.y - core_range.start.y + 1; + + uint32_t units_to_divide = input_grad.volume() / H / W * Ht * Wt; + + auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = + split_work_to_cores(core_range, units_to_divide); + + auto 
arch = input_grad.device()->arch(); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + get_compute_kernel_config_args(arch, compute_kernel_config); + + Program program = Program(); + + // create circular buffers + tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); + + CreateCircularBuffer( + program, + all_cores, + data_format, + { + {CB::c_in0, 1, tt::DataFormat::Int32}, // target + {CB::c_in1, 1}, // output_grad + {CB::c_in2, static_cast(weight_has_value ? Ct : 0)}, // weight + {CB::c_out0, 1}, // input_grad + }); + + // create read/wrtie kernel + const std::vector reader_compile_time_args{ + static_cast(is_dram(target)), + static_cast(is_dram(output_grad)), + static_cast(is_dram(weight))}; + + const std::vector writer_compile_time_args{static_cast(is_dram(input_grad))}; + + std::map reader_defines; + std::map writer_defines; + + if (weight_has_value) { + reader_defines["WEIGHT"] = 1; + } + + if (fp32_dest_acc_en) { + reader_defines["FP32_DEST_ACC_EN"] = 1; + } + + auto reader_kernel_id = CreateReadKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "reader_moreh_nll_loss_unreduced_backward_4d.cpp", + all_cores, + reader_compile_time_args, + reader_defines); + auto writer_kernel_id = CreateWriteKernel( + program, + "tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/kernels/" + "writer_moreh_nll_loss_unreduced_backward.cpp", + all_cores, + writer_compile_time_args, + writer_defines); + + + const auto target_addr = target.buffer()->address(); + const auto output_grad_addr = output_grad.buffer()->address(); + const auto weight_addr = weight_has_value ? weight.value().buffer()->address() : 0; + const auto input_grad_addr = input_grad.buffer()->address(); + + // Set Runtime Args + auto core_x_offset = core_range.start.x; + auto core_y_offset = core_range.start.y; + for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { + CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; + uint32_t units_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + units_per_core = units_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges"); + } + + std::vector reader_args = { + target_addr, + output_grad_addr, + weight_addr, + static_cast(ignore_index), + units_per_core, + tile_offset, + num_inner_tile, + channel_size, + Ct, + }; + + std::vector writer_args = {input_grad_addr, units_per_core, tile_offset}; + + SetRuntimeArgs(program, reader_kernel_id, core, reader_args); + SetRuntimeArgs(program, writer_kernel_id, core, writer_args); + + tile_offset += units_per_core; + } + + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; +} + + +} // namespace + +operation::ProgramWithCallbacks moreh_nll_loss_unreduced_backward_impl( + const Tensor &target, + const std::optional weight, + const Tensor &output_grad, + const Tensor &input_grad, + const int32_t ignore_index, + const CoreRange core_range, + const DeviceComputeKernelConfig compute_kernel_config) { + // split work + auto input_grad_shape = input_grad.get_legacy_shape(); + auto input_grad_rank = input_grad_shape.rank(); + + if (input_grad_rank == 2) { + return moreh_nll_loss_unreduced_backward_impl_2d( + target, + weight, + output_grad, + 
input_grad, + ignore_index, + core_range, + compute_kernel_config); + } + + if (input_grad_rank == 3) { + return moreh_nll_loss_unreduced_backward_impl_3d( + target, + weight, + output_grad, + input_grad, + ignore_index, + core_range, + compute_kernel_config); + } + + if (input_grad_rank >= 4) { + return moreh_nll_loss_unreduced_backward_impl_4d( + target, + weight, + output_grad, + input_grad, + ignore_index, + core_range, + compute_kernel_config); + } + + return moreh_nll_loss_unreduced_backward_impl_4d( + target, + weight, + output_grad, + input_grad, + ignore_index, + core_range, + compute_kernel_config); +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.cpp new file mode 100644 index 00000000000..05be109a621 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.cpp @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp" + +#include "tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" + +using namespace tt::constants; +using namespace std; +using namespace tt::tt_metal; + +namespace tt { + +namespace { + +inline void check_tensor(const Tensor &tensor, const std::string &op_name) { + TT_FATAL(tensor.get_layout() == Layout::TILE, "{} only supports tiled layout.", op_name); + TT_FATAL(tensor.get_dtype() == DataType::BFLOAT16, "{} only supports bfloat16.", op_name); + TT_FATAL( + tensor.storage_type() == StorageType::DEVICE, "Operands to {} need to be on device!", op_name); + TT_FATAL( + tensor.buffer() != nullptr, "Operands to {} need to be allocated in buffers on device!", op_name); +} + +inline void check_tensor(std::optional tensor, const std::string &op_name) { + if (!tensor.has_value()) { + return; + } + check_tensor(tensor.value(), op_name); +} + +} // namespace + +namespace operations { +namespace primary { + +void MorehNllLossUnreducedBackward::validate_with_output_tensors( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector>& output_tensors) const { + TT_FATAL(input_tensors.size() == 2, "Must have 2 input tensors"); + TT_FATAL(optional_input_tensors.size() == 1, "Must have 1 optional input tensors"); + + const auto& target_tensor = input_tensors.at(0); + const auto& output_grad_tensor = input_tensors.at(1); + + const auto& weight_tensor = optional_input_tensors.at(0); + + const auto& input_grad_tensor = output_tensors.at(0); + + TT_FATAL(target_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss_unreduced need to be on device!"); + TT_FATAL(target_tensor.buffer() != nullptr, "Operands to nll_loss_unreduced need to be allocated in buffers on device!"); + TT_FATAL((target_tensor.get_layout() == Layout::TILE), "target_tensor to nll_loss_unreduced must be tilized"); + TT_FATAL(target_tensor.get_dtype() == DataType::INT32); + + TT_FATAL(output_grad_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss_unreduced need to be on device!"); + TT_FATAL( + output_grad_tensor.buffer() != 
nullptr, "Operands to nll_loss_unreduced need to be allocated in buffers on device!"); + TT_FATAL((output_grad_tensor.get_layout() == Layout::TILE), "target_tensor to nll_loss_unreduced must be tilized"); + TT_FATAL(output_grad_tensor.get_dtype() == DataType::BFLOAT16); + + if (input_grad_tensor.has_value()) { + TT_FATAL( + input_grad_tensor.value().storage_type() == StorageType::DEVICE, + "Operands to nll_loss need to be on device!"); + TT_FATAL( + input_grad_tensor.value().buffer() != nullptr, + "Operands to nll_loss need to be allocated in buffers on device!"); + TT_FATAL( + (input_grad_tensor.value().get_layout() == Layout::TILE), "target_tensor to nll_loss_unreduced must be tilized"); + TT_FATAL(input_grad_tensor.value().get_dtype() == DataType::BFLOAT16); + } + + if (weight_tensor.has_value()) { + TT_FATAL( + weight_tensor.value().storage_type() == StorageType::DEVICE, + "weight_tensor to nll_loss need to be on device!"); + TT_FATAL( + weight_tensor.value().buffer() != nullptr, + "weight_tensor to nll_loss need to be allocated in buffers on device!"); + TT_FATAL((weight_tensor.value().get_layout() == Layout::TILE), "weight_tensor to nll_loss_unreduced must be in tilized"); + TT_FATAL(weight_tensor.value().get_dtype() == DataType::BFLOAT16); + } +} + +std::vector MorehNllLossUnreducedBackward::compute_output_shapes(const std::vector& input_tensors) const { + // To calculate the output shape, we need the channel_size. However, the required tensors, target and output_grad, + // do not contain the channel_size information. + TT_FATAL(false, "moreh_nll_loss_unreduced_backward not support create output tensors."); + return {input_tensors.at(0).get_legacy_shape()}; +} + +std::vector MorehNllLossUnreducedBackward::create_output_tensors( + const std::vector& input_tensors, const std::vector>& output_tensors) const { + if (output_tensors.at(0).has_value()) { + return {output_tensors.at(0).value()}; + } + + return operation::generic_create_output_tensors( + *this, input_tensors, input_tensors.at(1).get_dtype(), Layout::TILE, this->memory_config); +} + +operation::ProgramWithCallbacks MorehNllLossUnreducedBackward::create_program( + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + std::vector& output_tensors) const { + const auto& target = input_tensors.at(0); + const auto& output_grad = input_tensors.at(1); + + const auto& weight = optional_input_tensors.at(0); + + auto& input_grad = output_tensors.at(0); + + return {moreh_nll_loss_unreduced_backward_impl( + target, + weight, + output_grad, + input_grad, + this->ignore_index, + this->core_range, + this->compute_kernel_config)}; +} + +Tensor moreh_nll_loss_unreduced_backward( + const Tensor& target_tensor, + const std::optional weight_tensor, + const Tensor& output_grad_tensor, + const std::optional input_grad_tensor, + const int32_t ignore_index, + const MemoryConfig& memory_config, + std::optional compute_kernel_config) { + auto device = output_grad_tensor.device(); + auto grid_coord = device->compute_with_storage_grid_size(); + const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); + + auto kernel_config_val = + init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); + + std::vector output_tensors = {Tensor( + operation::get_workers_for_op_output({target_tensor, output_grad_tensor}, {weight_tensor}))}; + + operation::launch_op( + [ignore_index, memory_config, all_cores, kernel_config_val]( + const std::vector& input_tensors, + const std::vector>& 
optional_input_tensors, + const std::vector>& optional_output_tensors) mutable -> std::vector { + return operation::run( + MorehNllLossUnreducedBackward{ + .ignore_index = ignore_index, + .memory_config = memory_config, + .core_range = all_cores, + .compute_kernel_config = kernel_config_val}, + input_tensors, + optional_input_tensors, + optional_output_tensors); + }, + {target_tensor, output_grad_tensor}, + output_tensors, + {weight_tensor}, + {input_grad_tensor}); + + return output_tensors.at(0); +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp new file mode 100644 index 00000000000..74e7b3c6d4d --- /dev/null +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp @@ -0,0 +1,65 @@ +/* + * SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "tt_dnn/op_library/compute_kernel_config.hpp" +#include "tt_dnn/op_library/operation.hpp" +#include "tt_eager/tensor/tensor.hpp" + +namespace tt { +namespace operations { +namespace primary { + +using namespace tt_metal; + +operation::ProgramWithCallbacks moreh_nll_loss_unreduced_backward_impl( + const Tensor &target, + const std::optional weight, + const Tensor &output_grad, + const Tensor &input_grad, + const int32_t ignore_index, + const CoreRange core_range, + const DeviceComputeKernelConfig compute_kernel_config); + +struct MorehNllLossUnreducedBackward { + int32_t ignore_index; + + const MemoryConfig memory_config; + const CoreRange core_range; // unused for now + const DeviceComputeKernelConfig compute_kernel_config; + + void validate_with_output_tensors( + const std::vector &input_tensors, + const std::vector> &optional_input_tensors, + const std::vector> &output_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors( + const std::vector &input_tensors, const std::vector> &output_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, + const std::vector> &optional_input_tensors, + std::vector &output_tensors) const; + static constexpr auto attribute_names = std::make_tuple("ignore_index", "memory_config", "compute_kernel_config"); + const auto attribute_values() const { return std::make_tuple( + std::cref(this->ignore_index), + std::cref(this->memory_config), + std::cref(this->compute_kernel_config) + ); } +}; + +Tensor moreh_nll_loss_unreduced_backward( + const Tensor &target_tensor, + const std::optional weight_tensor, + const Tensor &output_grad_tensor, + const std::optional input_grad_tensor, + const int32_t ignore_index, + const MemoryConfig &memory_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional compute_kernel_config = std::nullopt); + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index ee69d300650..70760df1a8c 100644 --- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -31,7 +31,9 @@ #include "tt_dnn/op_library/moreh_mean/moreh_mean_op.hpp" #include "tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward_op.hpp" #include 
"tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_op.hpp" +#include "tt_dnn/op_library/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.hpp" #include "tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.hpp" +#include "tt_dnn/op_library/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.hpp" #include "tt_dnn/op_library/moreh_norm/moreh_norm_op.hpp" #include "tt_dnn/op_library/moreh_norm_backward/moreh_norm_backward_op.hpp" #include "tt_dnn/op_library/moreh_sgd/moreh_sgd_op.hpp" @@ -364,6 +366,32 @@ void py_module(py::module& m_primary) { py::arg("compute_kernel_config").noconvert() = std::nullopt, "Performs a nll_loss_backward operation. Returns an input_grad tensor."); + // moreh_nll_loss_unreduced + m_primary.def( + "moreh_nll_loss_unreduced", + &moreh_nll_loss_unreduced, + py::arg("input_tensor").noconvert(), + py::arg("target_tensor").noconvert(), + py::arg("weight_tensor").noconvert() = std::nullopt, + py::arg("output_tensor").noconvert() = std::nullopt, + py::arg("ignore_index").noconvert(), + py::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + "Performs a nll_loss_unreduced operation. Returns an output tensor."); + + // moreh_nll_loss_unreduced_backward + m_primary.def( + "moreh_nll_loss_unreduced_backward", + &moreh_nll_loss_unreduced_backward, + py::arg("target_tensor").noconvert(), + py::arg("weight_tensor").noconvert() = std::nullopt, + py::arg("output_grad_tensor").noconvert(), + py::arg("input_grad_tensor").noconvert() = std::nullopt, + py::arg("ignore_index").noconvert(), + py::arg("memory_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + py::arg("compute_kernel_config").noconvert() = std::nullopt, + "Performs a nll_loss_unreduced_backward operation. Returns an input_grad tensor."); + // moreh_norm m_primary.def( "moreh_norm",