diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py index 32f8a179d76..556aeab6bb7 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py @@ -139,12 +139,7 @@ def run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight @pytest.mark.parametrize( "shape", - [ - (5, 10), - (3000, 100), - (200, 100, 90), - (5, 50, 2, 7, 50, 70), - ], + [[5, 10], [3000, 100], [200, 100, 90], [5, 50, 2, 7, 50, 70]], ) @pytest.mark.parametrize("ignore_index", [1]) @pytest.mark.parametrize("reduction", ["mean", "sum"]) @@ -158,9 +153,9 @@ def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, device): @pytest.mark.parametrize( "shape", [ - (5, 10), - (5, 6, 7), - (5, 6, 8, 9), + [5, 10], + [5, 6, 7], + [5, 6, 8, 9], ], ) @pytest.mark.parametrize("reduction", ["mean", "sum"]) @@ -172,15 +167,17 @@ def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_prog for _ in range(2): run_moreh_nll_loss(shape, ignore_idx, reduction, none_weight, device) + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) @pytest.mark.parametrize( "shape", [ - (400, 300), - (20, 300, 320), - (3, 4, 32 * 5, 32 * 6), - (5, 2, 5, 40, 70), + [400, 300], + [20, 300, 320], + [3, 4, 32 * 5, 32 * 6], + [5, 2, 5, 40, 70], ], ) @pytest.mark.parametrize("ignore_index", [1]) @@ -195,9 +192,9 @@ def test_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weigh @pytest.mark.parametrize( "shape", [ - (2, 3), - (2, 3, 4), - (2, 3, 5, 4), + [2, 3], + [2, 3, 4], + [2, 3, 5, 4], ], ) @pytest.mark.parametrize("reduction_mean", [True, False]) @@ -209,14 +206,16 @@ def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weigh for _ in range(2): run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight, device) + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) @pytest.mark.parametrize( "shape", [ - (5, 10), - (10, 20, 30), - (10, 20, 30, 40), + [5, 10], + [10, 20, 30], + [10, 20, 30, 40], ], ) @pytest.mark.parametrize("ignore_index", [1]) @@ -236,9 +235,9 @@ def test_moreh_nll_loss_compute_kernel_options( @pytest.mark.parametrize( "shape", [ - (5, 10), - (10, 20, 30), - (10, 20, 30, 40), + [5, 10], + [10, 20, 30], + [10, 20, 30, 40], ], ) @pytest.mark.parametrize("reduction_mean", [True, False]) diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1_large.cpp new file mode 100644 index 00000000000..1f0ccc88cc5 --- /dev/null +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1_large.cpp @@ -0,0 +1,102 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" + +void kernel_main() { + uint32_t i = 0; + auto target_addr = get_arg_val(i++); + auto weight_addr = get_arg_val(i++); + auto ignore_index = static_cast(get_arg_val(i++)); + auto num_units_per_core = get_arg_val(i++); + auto start_id = get_arg_val(i++); + auto N = get_arg_val(i++); + auto C = get_arg_val(i++); + auto weight_num_tile = get_arg_val(i++); + auto element_size = get_arg_val(i++); + auto target_element_size = get_arg_val(i++); + + constexpr uint32_t cb_target = tt::CB::c_in0; + constexpr uint32_t cb_weight = tt::CB::c_in1; + + constexpr uint32_t cb_output = tt::CB::c_out0; + + // ublocks size defined in tiles + const uint32_t target_tile_bytes = get_tile_size(cb_target); + + constexpr bool target_is_dram = get_compile_time_arg_val(0) == 1; +#if defined(WEIGHT) + constexpr bool weight_is_dram = get_compile_time_arg_val(1) == 1; + constexpr bool weight_has_value = get_compile_time_arg_val(2) == 1; +#endif + + const InterleavedAddrGen addrg_target = { + .bank_base_address = target_addr, .page_size = target_tile_bytes}; + +#if defined(WEIGHT) + const uint32_t weight_tile_bytes = get_tile_size(cb_weight); + auto weight_element_size = weight_tile_bytes / 1024; + const DataFormat weight_data_format = get_dataformat(cb_weight); + const InterleavedAddrGen addrg_weight = { + .bank_base_address = weight_addr, + .page_size = weight_tile_bytes, + }; +#endif + + constexpr uint32_t onetile = 1; + + Scalar one, zero; + one.f = 1.0f; + zero.f = 0.0f; + + const auto u16_one = uint16_t(one.u >> 16); + const auto u16_zero = uint16_t(zero.u >> 16); + + uint32_t end_id = start_id + num_units_per_core; + for (uint32_t i = start_id; i < end_id; ++i) { + // target: (N, d1, d2, .. dk) + uint32_t target_noc_id = i; + read_tile(cb_target, addrg_target, target_noc_id); + + cb_reserve_back(cb_output, onetile); + cb_wait_front(cb_target, onetile); + + auto output_l1_ptr = get_write_ptr(cb_output); + auto target_l1_ptr = get_read_ptr(cb_target); + + for (uint32_t h = 0; h < TILE_HEIGHT; h++) { + for (uint32_t w = 0; w < TILE_WIDTH; w++) { + uint32_t inout_idx = h * TILE_WIDTH + w; + int32_t target_val = target_l1_ptr[inout_idx]; + if (target_val != ignore_index) { + if (0 <= target_val && target_val < static_cast(C)) { +#if defined(WEIGHT) + uint32_t target_idx = target_val; + + uint32_t noc_id = target_idx / TILE_WIDTH; + uint32_t weight_tilized_idx = get_tilized_idx(0, target_idx); + read_value(cb_weight, addrg_weight, noc_id, weight_tilized_idx); + + cb_wait_front(cb_weight, onetile); + auto weight_l1_ptr = get_read_ptr(cb_weight); + + output_l1_ptr[inout_idx] = weight_l1_ptr[weight_tilized_idx]; + + cb_pop_front(cb_weight, onetile); +#else + output_l1_ptr[inout_idx] = u16_one; +#endif + } else { + output_l1_ptr[inout_idx] = u16_zero; + } + } else { + output_l1_ptr[inout_idx] = u16_zero; + } + } + } + cb_push_back(cb_output, onetile); + + cb_pop_front(cb_target, onetile); + } +} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp index 26f5a6281e7..35ff97a0194 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp @@ -2,12 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/run_operation.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" +#include "ttnn/run_operation.hpp" using namespace tt::constants; using namespace std; @@ -48,28 +48,56 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl( auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(core_range, units_to_divide); - auto arch = target.device()->arch(); + auto* device = target.device(); + auto arch = device->arch(); auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - - auto fp32_dest_acc_en_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - uint32_t weight_num_tile = div_up(channel_size, TILE_WIDTH); - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 1, tt::DataFormat::Int32}, // traget - {CB::c_in1, weight_num_tile}, // weight - {CB::c_intermed0, 1, fp32_dest_acc_en_data_format}, // tmp_weight - {CB::c_out0, 1}, // output - }); + const auto target_data_format = tt_metal::datatype_to_dataformat_converter(target.get_dtype()); + const auto data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + const auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; + + const auto target_tile_size = tt_metal::detail::TileSize(target_data_format); + const auto data_tile_size = tt_metal::detail::TileSize(data_format); + const auto intermed_tile_size = tt_metal::detail::TileSize(intermed_data_format); + + const uint32_t available_L1 = device->l1_size_per_core() - L1_UNRESERVED_BASE; + + uint32_t target_num_tile = 1; + uint32_t weight_num_tile = weight_has_value ? div_up(channel_size, TILE_WIDTH) : 0; + uint32_t intermed_num_tile = 1; + uint32_t output_num_tile = 1; + uint32_t cb_usage = target_num_tile * target_tile_size + weight_num_tile * data_tile_size + + intermed_num_tile * intermed_tile_size + output_num_tile * data_tile_size; + + const bool use_large_algorithm = cb_usage >= available_L1;; + + if (use_large_algorithm) { + CreateCircularBuffer( + program, + all_cores, + data_format, + { + {CB::c_in0, 1, tt::DataFormat::Int32}, // traget + {CB::c_in1, 1}, // weight + {CB::c_intermed0, 1, intermed_data_format}, // tmp_weight + {CB::c_out0, 1}, // output + }); + } else { + CreateCircularBuffer( + program, + all_cores, + data_format, + { + {CB::c_in0, 1, tt::DataFormat::Int32}, // traget + {CB::c_in1, weight_num_tile}, // weight + {CB::c_intermed0, 1, intermed_data_format}, // tmp_weight + {CB::c_out0, 1}, // output + }); + } // create read/wrtie kernel const std::vector reader_compile_time_args{ @@ -89,19 +117,17 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl( if (fp32_dest_acc_en) { reader_defines["FP32_DEST_ACC_EN"] = 1; } - - auto reader_kernel_id = CreateReadKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1.cpp", - all_cores, - reader_compile_time_args, - reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/writer_moreh_nll_loss_step1.cpp", - all_cores, - writer_compile_time_args, - writer_defines); + const auto reader_kernel_file = + use_large_algorithm ? "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1_large.cpp" + : "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/reader_moreh_nll_loss_step1.cpp"; + const auto writer_kernel_file = + "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/kernels/" + "writer_moreh_nll_loss_step1.cpp"; + + auto reader_kernel_id = + CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args, reader_defines); + auto writer_kernel_id = + CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args, writer_defines); const auto target_addr = target.buffer()->address(); const auto weight_addr = weight_has_value ? weight.value().buffer()->address() : 0;