diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/i0_bw/i0_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/i0_bw/i0_bw.py index 95773e86898..e29f265a4e8 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary_backward/i0_bw/i0_bw.py +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/i0_bw/i0_bw.py @@ -20,7 +20,7 @@ # Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. # Developers can create their own generator functions and pass them to the parameters as inputs. parameters = { - "xfail": { + "nightly": { "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 8) + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 8) + gen_shapes([1, 1], [256, 256], [1, 1], 8), @@ -72,7 +72,7 @@ def run( input_shape ) torch_input_tensor_a = gen_func_with_cast_tt( - partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + partial(torch_random, low=-10, high=10, dtype=torch.float32), input_a_dtype )(input_shape) torch_input_tensor_a.requires_grad = True @@ -100,6 +100,6 @@ def run( output_tensor = ttnn.to_torch(output_tensor) e2e_perf = stop_measuring_time(start_time) - pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99) + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) # print(pcc) return [pcc, e2e_perf] diff --git a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py index 20aa2720a18..598b0c47ecd 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py @@ -6,6 +6,7 @@ import pytest import ttnn from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.utils_for_testing import assert_with_pcc @pytest.mark.parametrize( @@ -28,3 +29,49 @@ def 
test_bw_i0(input_shapes, device):
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        [1, 1, 32, 32],
+        [4, 2, 96, 192],
+        [4, 7, 21, 133],
+        [4, 6, 105, 245],
+    ],
+)
+def test_i0_bw_range(device, shapes):
+    """Compare ttnn.i0_bw with its golden function for inputs in [-10, 10) and gradients in [-5, 5)."""
+    torch.manual_seed(3624344)
+    # torch.rand is uniform on [0, 1), so scaling by (high - low) and shifting by low yields [low, high).
+    low, high = -10, 10
+    torch_input_tensor_a = torch.rand(shapes, dtype=torch.float32, requires_grad=True) * (high - low) + low
+
+    # Upstream gradient fed to the backward op.
+    low, high = -5, 5
+    grad_tensor_a = torch.rand(shapes, dtype=torch.float32) * (high - low) + low
+
+    golden_fn = ttnn.get_golden_function(ttnn.i0_bw)
+    torch_output_tensor = golden_fn(grad_tensor_a, torch_input_tensor_a)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a,
+        layout=ttnn.TILE_LAYOUT,
+        dtype=ttnn.bfloat16,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+    grad_tensor = ttnn.from_torch(
+        grad_tensor_a,
+        layout=ttnn.TILE_LAYOUT,
+        dtype=ttnn.bfloat16,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+    output_tensor = ttnn.i0_bw(grad_tensor, input_tensor_a, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    output_tensor = ttnn.to_torch(output_tensor[0])
+
+    torch_output_tensor = torch_output_tensor[0]
+    # NOTE(review): PCC ignores a constant scale factor or offset, so it cannot catch a kernel that is off by a fixed multiple.
+    pcc = ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor)
+    assert pcc >= 0.9998
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary_i1.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary_i1.py
new file mode 100644
index 00000000000..a1d8e5285ec
--- /dev/null
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary_i1.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+import torch
+
+import ttnn
+from models.utility_functions import skip_for_grayskull
+
+
+@skip_for_grayskull("Unsupported dtype for Grayskull")
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        [1, 1, 32, 32],
+        [4, 2, 96, 192],
+        [4, 7, 21, 133],
+        [4, 6, 105, 245],
+        [64, 64],
+        [3, 128, 512],
+    ],
+)
+def test_i1_range(device, shapes):
+    """Compare ttnn.i1 on float32 inputs sampled from [-10, 10) with torch.special.i1 using PCC."""
+    torch.manual_seed(0)
+    # torch.rand is uniform on [0, 1); rescale it to [low, high).
+    low, high = -10, 10
+    torch_input = torch.rand(shapes, dtype=torch.float32) * (high - low) + low
+    expected = torch.special.i1(torch_input)
+
+    tt_input = ttnn.from_torch(
+        torch_input,
+        layout=ttnn.TILE_LAYOUT,
+        dtype=ttnn.float32,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+    actual = ttnn.i1(tt_input, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    actual = ttnn.to_torch(actual)
+
+    # NOTE(review): PCC ignores a constant scale factor or offset; consider an extra torch.allclose check on the values.
+    assert ttnn.pearson_correlation_coefficient(expected, actual) >= 0.9999
+
+
+@skip_for_grayskull("Unsupported dtype for Grayskull")
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        [4, 2, 96, 192],
+        [1, 1, 64, 64],
+    ],
+)
+def test_i1_zero(device, shapes):
+    """Check ttnn.i1 of an all-zero tensor (sent to the device as bfloat16) against torch.special.i1."""
+    torch.manual_seed(0)
+    torch_zeros = torch.zeros(shapes, dtype=torch.float32)
+    expected = torch.special.i1(torch_zeros)
+
+    tt_input = ttnn.from_torch(
+        torch_zeros,
+        layout=ttnn.TILE_LAYOUT,
+        dtype=ttnn.bfloat16,
+        device=device,
+        memory_config=ttnn.DRAM_MEMORY_CONFIG,
+    )
+    actual = ttnn.i1(tt_input, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    actual = ttnn.to_torch(actual)
+
+    assert ttnn.pearson_correlation_coefficient(expected, actual) >= 0.9999
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h
index 91a4c684384..31f888030db 100644
---
a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_sfpu_api.h @@ -26,3 +26,4 @@ #include "llk_math_eltwise_unary_sfpu_unary_comp.h" #include "llk_math_eltwise_unary_sfpu_fill.h" #include "llk_math_eltwise_unary_sfpu_prelu.h" +#include "llk_math_eltwise_unary_sfpu_i1.h" diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h index 2831c9843c2..c42507e55ad 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h @@ -22,11 +22,11 @@ namespace sfpu { t4) * \ t4) * \ t4) -template +template inline void calculate_i0() { #pragma GCC unroll 0 - for (int d = 0; d < 8; d++) { + for (int d = 0; d < ITERATIONS; d++) { vFloat result = 0.0f; vFloat input = dst_reg[0]; vFloat x = input * input; diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h new file mode 100644 index 00000000000..65759b9086d --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+
+using namespace sfpi;
+
+namespace ckernel {
+
+namespace sfpu {
+// POLYVAL10_I1(c10, ..., c0, t) expands to c0*t + c1*t^2 + ... + c10*t^11: Horner form with one extra trailing "* t", so there is no constant term.
+#define POLYVAL10_I1(coef10, coef9, coef8, coef7, coef6, coef5, coef4, coef3, coef2, coef1, coef0, t2) \
+    ((coef0 + \
+      (coef1 + \
+       (coef2 + \
+        (coef3 + \
+         (coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * \
+            t2) * \
+           t2) * \
+          t2) * \
+     t2)
+// Writes I1(x) (modified Bessel function of the first kind, order 1) over each dst_reg[0] it visits, using the series I1(x) = sum_{k>=0} x^(2k+1) / (2^(2k+1) * k! * (k+1)!) truncated after k = 11.
+template  // NOTE(review): the parameter list is missing in this copy of the patch (ITERATIONS is used below) - restore it from the original commit.
+inline void calculate_i1() {
+#pragma GCC unroll 0
+    // Each pass handles dst_reg[0] and then advances dst_reg; ITERATIONS passes are made.
+    for (int d = 0; d < ITERATIONS; d++) {
+        // Coefficient of x^(2k+1) is 1 / (2^(2k+1) * k! * (k+1)!). The literals used previously were twice that for every k >= 1, which made the kernel return 2*I1(x) - x/2.
+        vFloat input = dst_reg[0];
+        vFloat input_sq = input * input;
+        // k = 1..11 terms: x * (x^2/16 + x^4/384 + ...); the k = 0 term x/2 is added afterwards.
+        vFloat higher_order = input * POLYVAL10_I1(
+                                          6.23473e-24f,
+                                          3.29194e-21f,
+                                          1.44845e-18f,
+                                          5.21443e-16f,
+                                          1.50175e-13f,
+                                          3.36393e-11f,
+                                          5.6514e-09f,
+                                          6.78168e-07f,
+                                          5.42535e-05f,
+                                          0.00260417f,
+                                          0.0625f,
+                                          input_sq);
+        vFloat result = input * 0.5f + higher_order;
+        dst_reg[0] = result;
+        dst_reg++;
+    }
+}
+
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
new file mode 100644
index 00000000000..c1e8a2eeb1c
--- /dev/null
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
@@ -0,0 +1,26 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "llk_math_eltwise_unary_sfpu_init.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "ckernel_sfpu_i1.h"
+
+namespace ckernel {
+// NOTE(review): both `template` lines below have no parameter list in this copy of the patch; presumably it was lost in extraction - confirm against the original commit.
+// New LLK SFPU APIs
+// Init entry point for the i1 SFPU op; it only forwards to llk_math_eltwise_unary_sfpu_init().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_init() {
+    llk_math_eltwise_unary_sfpu_init();
+}
+// Compute entry point: hands ckernel::sfpu::calculate_i1, dst_index and VectorMode::RC to llk_math_eltwise_unary_sfpu_params().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_op(uint dst_index) {
+    llk_math_eltwise_unary_sfpu_params(
+        ckernel::sfpu::calculate_i1, dst_index, (int)VectorMode::RC);
+}
+
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
index 8a7616784ed..22640cdd93d 100644
--- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h
@@ -59,6 +59,7 @@ enum SfpuType {
     logical_not_unary,
     erfinv,
     i0,
+    i1,
     silu,
     mask,
     negative,
diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h
index 97a2e535abd..5a32bc27bd9 100644
--- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h
+++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h
@@ -20,6 +20,7 @@
 #include "llk_math_eltwise_unary_sfpu_trigonometry.h"
 #include "llk_math_eltwise_unary_sfpu_unary_comp.h"
 #include "llk_math_eltwise_unary_sfpu_fill.h"
+#include "llk_math_eltwise_unary_sfpu_i1.h"
 
 namespace ckernel {
 
diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h
new file mode 100644
index 00000000000..d590815056d
--- /dev/null
+++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h
@@ -0,0 +1,58 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+
+#include "sfpi.h"
+
+using namespace sfpi;
+
+namespace ckernel {
+namespace sfpu {
+// POLYVAL10_I1(c10, ..., c0, t) expands to c0*t + c1*t^2 + ... + c10*t^11: Horner form with one extra trailing "* t", so there is no constant term.
+#define POLYVAL10_I1(coef10, coef9, coef8, coef7, coef6, coef5, coef4, coef3, coef2, coef1, coef0, t2) \
+    ((coef0 + \
+      (coef1 + \
+       (coef2 + \
+        (coef3 + \
+         (coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * \
+            t2) * \
+           t2) * \
+          t2) * \
+     t2)
+// Writes I1(x) (modified Bessel function of the first kind, order 1) over each dst_reg[0] it visits, using the series I1(x) = sum_{k>=0} x^(2k+1) / (2^(2k+1) * k! * (k+1)!) truncated after k = 11.
+template  // NOTE(review): the parameter list is missing in this copy of the patch (ITERATIONS is used below) - restore it from the original commit.
+inline void calculate_i1() {
+#pragma GCC unroll 0
+    // Each pass handles dst_reg[0] and then advances dst_reg; ITERATIONS passes are made.
+    for (int d = 0; d < ITERATIONS; d++) {
+        // Coefficient of x^(2k+1) is 1 / (2^(2k+1) * k! * (k+1)!). The literals used previously were twice that for every k >= 1, which made the kernel return 2*I1(x) - x/2.
+        vFloat input = dst_reg[0];
+        vFloat input_sq = input * input;
+        // k = 1..11 terms: x * (x^2/16 + x^4/384 + ...); the k = 0 term x/2 is added afterwards.
+        vFloat higher_order = input * POLYVAL10_I1(
+                                          6.23473e-24f,
+                                          3.29194e-21f,
+                                          1.44845e-18f,
+                                          5.21443e-16f,
+                                          1.50175e-13f,
+                                          3.36393e-11f,
+                                          5.6514e-09f,
+                                          6.78168e-07f,
+                                          5.42535e-05f,
+                                          0.00260417f,
+                                          0.0625f,
+                                          input_sq);
+        vFloat result = input * 0.5f + higher_order;
+        dst_reg[0] = result;
+        dst_reg++;
+    }
+}
+
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
new file mode 100644
index 00000000000..dd8705d486e
--- /dev/null
+++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
@@ -0,0 +1,26 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "llk_math_eltwise_unary_sfpu_init.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "ckernel_sfpu_i1.h"
+
+namespace ckernel {
+// NOTE(review): both `template` lines below have no parameter list in this copy of the patch; presumably it was lost in extraction - confirm against the original commit.
+// New LLK SFPU APIs
+// Init entry point for the i1 SFPU op; it only forwards to llk_math_eltwise_unary_sfpu_init().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_init() {
+    llk_math_eltwise_unary_sfpu_init();
+}
+// Compute entry point: hands ckernel::sfpu::calculate_i1, dst_index and VectorMode::RC to llk_math_eltwise_unary_sfpu_params().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_op(uint dst_index) {
+    llk_math_eltwise_unary_sfpu_params(
+        ckernel::sfpu::calculate_i1, dst_index, (int)VectorMode::RC);
+}
+
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h
index dea1cc29e1f..58c2586236d 100644
--- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h
@@ -57,6 +57,7 @@ enum SfpuType {
     logical_not_unary,
     erfinv,
     i0,
+    i1,
     silu,
     mask,
     negative,
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
index 78104b0ce32..b1ae609e4fb 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h
@@ -37,3 +37,4 @@
 #include "llk_math_eltwise_unary_sfpu_left_shift.h"
 #include "llk_math_eltwise_unary_sfpu_fill.h"
 #include "llk_math_eltwise_unary_sfpu_prelu.h"
+#include "llk_math_eltwise_unary_sfpu_i1.h"
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h
index 21bc4601593..91602a61697 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i0.h
@@ -23,11 +23,11 @@ namespace sfpu {
           t4) * \
          t4) * \
      t4)
-template
+template
 inline void calculate_i0() {
 #pragma GCC
unroll 0
 
-    for (int d = 0; d < 8; d++) {
+    for (int d = 0; d < ITERATIONS; d++) {
         vFloat result = 0.0f;
         vFloat input = dst_reg[0];
         vFloat x = input * input;
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h
new file mode 100644
index 00000000000..65759b9086d
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h
@@ -0,0 +1,57 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "ckernel.h"
+#include "ckernel_defs.h"
+#include "noc_nonblocking_api.h"
+
+using namespace sfpi;
+
+namespace ckernel {
+
+namespace sfpu {
+// POLYVAL10_I1(c10, ..., c0, t) expands to c0*t + c1*t^2 + ... + c10*t^11: Horner form with one extra trailing "* t", so there is no constant term.
+#define POLYVAL10_I1(coef10, coef9, coef8, coef7, coef6, coef5, coef4, coef3, coef2, coef1, coef0, t2) \
+    ((coef0 + \
+      (coef1 + \
+       (coef2 + \
+        (coef3 + \
+         (coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * \
+            t2) * \
+           t2) * \
+          t2) * \
+     t2)
+// Writes I1(x) (modified Bessel function of the first kind, order 1) over each dst_reg[0] it visits, using the series I1(x) = sum_{k>=0} x^(2k+1) / (2^(2k+1) * k! * (k+1)!) truncated after k = 11.
+template  // NOTE(review): the parameter list is missing in this copy of the patch (ITERATIONS is used below) - restore it from the original commit.
+inline void calculate_i1() {
+#pragma GCC unroll 0
+    // Each pass handles dst_reg[0] and then advances dst_reg; ITERATIONS passes are made.
+    for (int d = 0; d < ITERATIONS; d++) {
+        // Coefficient of x^(2k+1) is 1 / (2^(2k+1) * k! * (k+1)!). The literals used previously were twice that for every k >= 1, which made the kernel return 2*I1(x) - x/2.
+        vFloat input = dst_reg[0];
+        vFloat input_sq = input * input;
+        // k = 1..11 terms: x * (x^2/16 + x^4/384 + ...); the k = 0 term x/2 is added afterwards.
+        vFloat higher_order = input * POLYVAL10_I1(
+                                          6.23473e-24f,
+                                          3.29194e-21f,
+                                          1.44845e-18f,
+                                          5.21443e-16f,
+                                          1.50175e-13f,
+                                          3.36393e-11f,
+                                          5.6514e-09f,
+                                          6.78168e-07f,
+                                          5.42535e-05f,
+                                          0.00260417f,
+                                          0.0625f,
+                                          input_sq);
+        vFloat result = input * 0.5f + higher_order;
+        dst_reg[0] = result;
+        dst_reg++;
+    }
+}
+
+}  // namespace sfpu
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
new file mode 100644
index 00000000000..514d7e4c465
--- /dev/null
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_i1.h
@@ -0,0 +1,25 @@
+// SPDX-FileCopyrightText: © 2024
Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "llk_math_eltwise_unary_sfpu_init.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
+#include "ckernel_sfpu_i1.h"
+
+namespace ckernel {
+// NOTE(review): both `template` lines below have no parameter list in this copy of the patch; presumably it was lost in extraction - confirm against the original commit.
+// New LLK SFPU APIs
+// Init entry point for the i1 SFPU op; it only forwards to llk_math_eltwise_unary_sfpu_init().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_init() {
+    llk_math_eltwise_unary_sfpu_init();
+}
+// Compute entry point: hands ckernel::sfpu::calculate_i1, dst_index and VectorMode::RC to llk_math_eltwise_unary_sfpu_params().
+template
+inline void llk_math_eltwise_unary_sfpu_i1_op(uint dst_index) {
+    llk_math_eltwise_unary_sfpu_params(
+        ckernel::sfpu::calculate_i1, dst_index, (int)VectorMode::RC);
+}
+}  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
index 7b3789c3743..77f70afe745 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h
@@ -60,6 +60,7 @@ enum SfpuType {
     logical_not_unary,
     erfinv,
     i0,
+    i1,
     silu,
     mask,
     negative,
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/i1.h b/tt_metal/include/compute_kernel_api/eltwise_unary/i1.h
new file mode 100644
index 00000000000..744b2f9a82b
--- /dev/null
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/i1.h
@@ -0,0 +1,38 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "compute_kernel_api/common_globals.h"
+#ifdef TRISC_MATH
+#include "llk_math_eltwise_unary_sfpu_i1.h"
+#define MAIN math_main()
+#define MATH(x) x
+#else
+#define MATH(x)
+#endif
+
+namespace ckernel {
+
+/**
+ * Performs element-wise computation of the first order modified Bessel function of the first kind on each element of a
+ * tile in DST register at index tile_index. The DST register buffer must be in acquired state via *acquire_dst* call.
+ * This call is blocking and is only available on the compute engine.
+ *
+ * Return value: None
+ *
+ * | Argument       | Description                                                                | Type     | Valid Range                                           | Required |
+ * |----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
+ * | tile_index     | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True     |
+ *
+ * Note: the argument documented as tile_index is named idst in the signature below.
+ */
+ALWI void i1_tile(uint32_t idst) { MATH((llk_math_eltwise_unary_sfpu_i1_op(idst))); }
+
+/**
+ * Please refer to documentation for any_init.
+ */
+ALWI void i1_tile_init() { MATH((llk_math_eltwise_unary_sfpu_i1_init())); }
+
+}  // namespace ckernel
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
index 204a1559546..443a16db0a1 100644
--- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h
@@ -44,6 +44,10 @@
 #include "compute_kernel_api/eltwise_unary/i0.h"
 #endif
 
+#if SFPU_OP_I1_INCLUDE
+#include "compute_kernel_api/eltwise_unary/i1.h"
+#endif
+
 #if SFPU_OP_ERFINV_INCLUDE
 #include "compute_kernel_api/eltwise_unary/erfinv.h"
 #endif
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
index b7ef3bf6995..394647cdb68 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
@@ -56,6 +56,7 @@ enum class UnaryOpType {
     ISFINITE,
     ERFINV,
     I0,
+    I1,
     TAN,
     RSUB,
     RDIV,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
index 977033f35a8..f18aa992748 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
+++
b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp @@ -58,6 +58,7 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_default(UnaryOpType op_type, std: op_init_and_name = {"logical_not_unary_tile_init();", fmt::format("logical_not_unary_tile({});", idst)}; break; case UnaryOpType::I0: op_init_and_name = {"i0_tile_init();", fmt::format("i0_tile({});", idst)}; break; + case UnaryOpType::I1: op_init_and_name = {"i1_tile_init();", fmt::format("i1_tile({});", idst)}; break; case UnaryOpType::ERFINV: op_init_and_name = {"erfinv_tile_init();", fmt::format("erfinv_tile({});", idst)}; break; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp index f98e3d532fc..81eba41d570 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp @@ -104,6 +104,7 @@ template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; +template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; template struct ExecuteUnary; diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp index 46cf0401c29..79d5d22eb5a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp @@ -286,6 +286,7 @@ REGISTER_UNARY_OPERATION(ceil, CEIL); REGISTER_UNARY_OPERATION(gez, GEZ); REGISTER_UNARY_OPERATION(gtz, GTZ); REGISTER_UNARY_OPERATION(i0, I0); +REGISTER_UNARY_OPERATION(i1, I1); REGISTER_UNARY_OPERATION(isfinite, ISFINITE); REGISTER_UNARY_OPERATION(isinf, ISINF); REGISTER_UNARY_OPERATION(isnan, ISNAN); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 1ed1fe0d081..d13bde28419 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ 
b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1742,6 +1742,7 @@ void py_module(py::module& module) {
 
     detail::bind_unary_operation(module, ttnn::gtz, R"doc(\mathrm{{output\_tensor}}_i= (\mathrm{{input\_tensor_i\ > 0}}))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
     detail::bind_unary_operation(module, ttnn::i0, R"doc(\mathrm{{output\_tensor}}_i = \verb|i0|(\mathrm{{input\_tensor}}_i))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
+    detail::bind_unary_operation(module, ttnn::i1, R"doc(\mathrm{{output\_tensor}}_i = \verb|i1|(\mathrm{{input\_tensor}}_i))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
     detail::bind_unary_operation(module, ttnn::isfinite, R"doc(\mathrm{{output\_tensor}}_i = \verb|isfinite|(\mathrm{{input\_tensor}}_i))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
     detail::bind_unary_operation(module, ttnn::isinf, R"doc(\mathrm{{output\_tensor}}_i = \verb|isinf|(\mathrm{{input\_tensor}}_i))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
     detail::bind_unary_operation(module, ttnn::isnan, R"doc(\mathrm{{output\_tensor}}_i = \verb|isnan|(\mathrm{{input\_tensor}}_i))doc", R"doc(BFLOAT16, BFLOAT8_B)doc");
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
index 8726732f0b8..a6a9e50998b 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
@@ -513,47 +513,14 @@ std::vector ExecuteUnaryBackwardFillZero::invoke(
     return grad_tensor;
 }
 
+// Backward of i0, implemented as grad * ttnn::i1(input); output_mem_config is forwarded to both the i1 and the multiply call.
+// Derivative rule being followed - name: i0(Tensor self) -> Tensor; self: grad * at::special_i1(self); result: auto_element_wise.
+// Returns a vector with a single tensor, the product computed below. NOTE(review): template arguments look stripped in this copy of the patch.
 std::vector ExecuteUnaryBackwardI0::invoke(
     const Tensor& grad, const Tensor& input, const std::optional& output_mem_config) {
     std::vector grad_tensor;
-    float t_inf = std::numeric_limits::infinity();
-    Tensor value = ttnn::multiply(
-        ttnn::multiply(
-            ttnn::i0(input, output_mem_config),
-            ttnn::reciprocal(input, output_mem_config),
-            std::nullopt,
-            output_mem_config),
-        0.5,
-        std::nullopt,
-        output_mem_config);
-    Tensor result = ttnn::where(
-        ttnn::ltz(input, output_mem_config),
-        ttnn::multiply(
-            grad,
-            ttnn::subtract(
-                ttnn::neg(ttnn::i0(input, output_mem_config), output_mem_config),
-                value,
-                std::nullopt,
-                output_mem_config),
-            std::nullopt,
-            output_mem_config),
-        ttnn::multiply(
-            grad,
-            ttnn::subtract(ttnn::i0(input, output_mem_config), value, std::nullopt, output_mem_config),
-            std::nullopt,
-            output_mem_config),
-        output_mem_config);
-    result = ttnn::where(
-        ttnn::ge(
-            ttnn::abs(ttnn::i0(input, output_mem_config), output_mem_config), 3.4e+38, std::nullopt, output_mem_config),
-        t_inf,
-        result,
-        output_mem_config);
-    result = ttnn::where(
-        ttnn::ge(ttnn::abs(result, output_mem_config), 3.4e+38, std::nullopt, output_mem_config),
-        t_inf,
-        result,
-        output_mem_config);
+    Tensor i1_input = ttnn::i1(input, output_mem_config);
+    Tensor result = ttnn::multiply(grad, i1_input, std::nullopt, output_mem_config);
     grad_tensor.emplace_back(result);
     return grad_tensor;
 }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index 40909f397bb..5abee661e1a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -1482,7 +1482,7 @@ void py_module(py::module& module) {
     detail::bind_unary_backward_op(
         module,
         ttnn::i0_bw,
-        R"doc(Performs backward operations for i0 on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
+        R"doc(Performs backward operations for i0 on :attr:`input_tensor` with given :attr:`grad_tensor`. Supported input range is (-10, 10))doc",
         R"doc(BFLOAT16, BFLOAT8_B)doc");
 
     detail::bind_unary_backward_op(