#13676: i1 op kernel implementation and improve i0_bw pcc (#15325)
### Ticket
Link to GitHub Issue #13676

### Problem description
- The current implementation of `i0_bw` uses the `reciprocal` op, which has an ongoing issue (#14672).
- The goal is to reimplement `i0_bw` on top of a new `i1` op with a dedicated kernel implementation (see the identity below).
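
For context (a standard Bessel-function identity, not spelled out in the ticket itself): the derivative of the modified Bessel function $I_0$ is $I_1$, so the backward of `i0` reduces to a single elementwise multiply and needs no `reciprocal` at all:

$$
\frac{d}{dx} I_0(x) = I_1(x)
\qquad\Longrightarrow\qquad
\texttt{i0\_bw}(\mathrm{grad},\ x) = \mathrm{grad} \cdot I_1(x)
$$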

### What's changed
- Implemented `ttnn::i1` as an eltwise unary kernel with PCC > 0.9999 (a usage sketch follows this list)
<img width="1512" alt="Screenshot 2024-11-23 at 1 21 33 AM"
src="https://github.com/user-attachments/assets/8e3aad76-f08c-46fa-bb37-c172c8125040">

- Reimplemented `ttnn.i0_bw` using `i1`, which gives a PCC of ~0.9998
<img width="1512" alt="Screenshot 2024-11-23 at 1 22 44 AM"
src="https://github.com/user-attachments/assets/05e10f0c-a74f-41b7-a285-acfdf42e9637">

- Updated the `i0_bw` sweep tests
<img width="1512" alt="Screenshot 2024-11-23 at 1 33 35 AM"
src="https://github.com/user-attachments/assets/0d6be449-fcac-4b6b-8d40-2477eb2fe9b6">

- Profiling (main vs. this branch):

| op | variant | count | python min dispatch time (ms) | python mean dispatch time (ms) | python mean dispatch + sync time (ms) | C++ mean dispatch time (ms) |
| --- | --- | --- | --- | --- | --- | --- |
| ttnn.i0_bw | main | 800 | 0.97 | 0.997 | 3.717 | 0.364 |
| ttnn.i0_bw | branch | 800 | 0.06 | 0.062 | 0.251 | 0.021 |

Mean dispatch + sync time drops from 3.717 ms to 0.251 ms (~15x), and C++ mean dispatch time from 0.364 ms to 0.021 ms (~17x).
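
A minimal usage sketch of the new forward op and the reworked backward, mirroring the unit tests added in this commit (the `device` handle is assumed to be an already-open ttnn device, as supplied by the pytest fixture):

```python
import torch
import ttnn

# Host-side reference
x = torch.rand(1, 1, 32, 32, dtype=torch.float32) * 20 - 10  # values in [-10, 10)
ref = torch.special.i1(x)

# Forward: the new ttnn.i1 eltwise unary op
x_tt = ttnn.from_torch(
    x, layout=ttnn.TILE_LAYOUT, dtype=ttnn.float32, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
y = ttnn.to_torch(ttnn.i1(x_tt, memory_config=ttnn.DRAM_MEMORY_CONFIG))
print(ttnn.pearson_correlation_coefficient(ref, y))  # expected > 0.9999

# Backward: the reworked ttnn.i0_bw (returns a list of gradient tensors)
grad = torch.rand_like(x)
grad_tt = ttnn.from_torch(
    grad, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
x_tt_bf16 = ttnn.from_torch(
    x, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
dx = ttnn.to_torch(ttnn.i0_bw(grad_tt, x_tt_bf16, memory_config=ttnn.DRAM_MEMORY_CONFIG)[0])
```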

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12065876004
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
KalaivaniMCW authored Nov 28, 2024
1 parent 592214d commit f401c2e
Showing 26 changed files with 434 additions and 46 deletions.
@@ -20,7 +20,7 @@
# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs.
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
"xfail": {
"nightly": {
"input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 8)
+ gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 8)
+ gen_shapes([1, 1], [256, 256], [1, 1], 8),
@@ -72,7 +72,7 @@ def run(
input_shape
)
torch_input_tensor_a = gen_func_with_cast_tt(
- partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
+ partial(torch_random, low=-10, high=10, dtype=torch.float32), input_a_dtype
)(input_shape)
torch_input_tensor_a.requires_grad = True

@@ -100,6 +100,6 @@ def run(
output_tensor = ttnn.to_torch(output_tensor)
e2e_perf = stop_measuring_time(start_time)

- pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99)
+ pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999)
# print(pcc)
return [pcc, e2e_perf]
@@ -6,6 +6,7 @@
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
@@ -28,3 +29,49 @@ def test_bw_i0(input_shapes, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"shapes",
[
[1, 1, 32, 32],
[4, 2, 96, 192],
[4, 7, 21, 133],
[4, 6, 105, 245],
],
)
def test_i0_bw_range(device, shapes):
torch.manual_seed(3624344) # 16305027

high = 10
low = -10
torch_input_tensor_a = torch.rand(shapes, dtype=torch.float32, requires_grad=True) * (high - low) + low

high = 5
low = -5
grad_tensor_a = torch.rand(shapes, dtype=torch.float32) * (high - low) + low

golden_fn = ttnn.get_golden_function(ttnn.i0_bw)
torch_output_tensor = golden_fn(grad_tensor_a, torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
grad_tensor = ttnn.from_torch(
grad_tensor_a,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
output_tensor = ttnn.i0_bw(grad_tensor, input_tensor_a, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor[0])

torch_output_tensor = torch_output_tensor[0]

pcc = ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor)
assert pcc >= 0.9998
71 changes: 71 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_unary_i1.py
@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn
from models.utility_functions import skip_for_grayskull


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"shapes",
[
[1, 1, 32, 32],
[4, 2, 96, 192],
[4, 7, 21, 133],
[4, 6, 105, 245],
[64, 64],
[3, 128, 512],
],
)
def test_i1_range(device, shapes):
torch.manual_seed(0)

high = 10
low = -10
torch_input_tensor_a = torch.rand(shapes, dtype=torch.float32) * (high - low) + low
torch_output_tensor = torch.special.i1(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.float32,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
output_tensor = ttnn.i1(input_tensor_a, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor)

pcc = ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor)
assert pcc >= 0.9999


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"shapes",
[
[4, 2, 96, 192],
[1, 1, 64, 64],
],
)
def test_i1_zero(device, shapes):
torch.manual_seed(0)

torch_input_tensor_a = torch.zeros(shapes, dtype=torch.float32)
torch_output_tensor = torch.special.i1(torch_input_tensor_a)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
layout=ttnn.TILE_LAYOUT,
dtype=ttnn.bfloat16,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
output_tensor = ttnn.i1(input_tensor_a, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor)

assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.9999
@@ -26,3 +26,4 @@
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_i1.h"
@@ -22,11 +22,11 @@ namespace sfpu {
t4) * \
t4) * \
t4)
- template <bool APPROXIMATION_MODE>
+ template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_i0() {
#pragma GCC unroll 0

- for (int d = 0; d < 8; d++) {
+ for (int d = 0; d < ITERATIONS; d++) {
vFloat result = 0.0f;
vFloat input = dst_reg[0];
vFloat x = input * input;
@@ -0,0 +1,57 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"

using namespace sfpi;

namespace ckernel {

namespace sfpu {
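
// POLYVAL10_I1(c10, ..., c0, t2) expands to t2 * (c0 + c1*t2 + ... + c10*t2^10), evaluated with
// nested multiplications. calculate_i1() below passes t2 = x * x and adds x * 0.5f, so the output
// is a truncated odd-power polynomial approximation of the modified Bessel function I1(x).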

#define POLYVAL10_I1(coef10, coef9, coef8, coef7, coef6, coef5, coef4, coef3, coef2, coef1, coef0, t2) \
((coef0 + \
(coef1 + \
(coef2 + \
(coef3 + \
(coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * \
t2) * \
t2) * \
t2) * \
t2)

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_i1() {
#pragma GCC unroll 0

for (int d = 0; d < ITERATIONS; d++) {
vFloat result = 0.0f;
vFloat input = dst_reg[0];
vFloat x = input * input;

vFloat derivative = input * POLYVAL10_I1(
1.24695e-23f,
6.58387e-21f,
2.8969e-18f,
1.04289e-15f,
3.00351e-13f,
6.72786e-11f,
1.13028e-08f,
1.35634e-06f,
0.000108507f,
0.00520833f,
0.125f,
x);
result = input * 0.5f + derivative;
dst_reg[0] = result;
dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_unary_sfpu_init.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "ckernel_sfpu_i1.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_i1_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::i1, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_i1_op(uint dst_index) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_i1<APPROXIMATE>, dst_index, (int)VectorMode::RC);
}

} // namespace ckernel
@@ -59,6 +59,7 @@ enum SfpuType {
logical_not_unary,
erfinv,
i0,
i1,
silu,
mask,
negative,
@@ -20,6 +20,7 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_i1.h"

namespace ckernel {

@@ -0,0 +1,58 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"

#include "sfpi.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

#define POLYVAL10_I1(coef10, coef9, coef8, coef7, coef6, coef5, coef4, coef3, coef2, coef1, coef0, t2) \
((coef0 + \
(coef1 + \
(coef2 + \
(coef3 + \
(coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * \
t2) * \
t2) * \
t2) * \
t2)

template <bool APPROXIMATION_MODE, int ITERATIONS>
inline void calculate_i1() {
#pragma GCC unroll 0

for (int d = 0; d < ITERATIONS; d++) {
vFloat result = 0.0f;
vFloat input = dst_reg[0];
vFloat x = input * input;

vFloat derivative = input * POLYVAL10_I1(
1.24695e-23f,
6.58387e-21f,
2.8969e-18f,
1.04289e-15f,
3.00351e-13f,
6.72786e-11f,
1.13028e-08f,
1.35634e-06f,
0.000108507f,
0.00520833f,
0.125f,
x);
result = input * 0.5f + derivative;
dst_reg[0] = result;
dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
@@ -0,0 +1,26 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_unary_sfpu_init.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "ckernel_sfpu_i1.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_i1_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::i1, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_i1_op(uint dst_index) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_i1<APPROXIMATE, 4>, dst_index, (int)VectorMode::RC);
}

} // namespace ckernel
@@ -57,6 +57,7 @@ enum SfpuType {
logical_not_unary,
erfinv,
i0,
i1,
silu,
mask,
negative,
@@ -37,3 +37,4 @@
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_i1.h"
@@ -23,11 +23,11 @@ namespace sfpu {
t4) * \
t4) * \
t4)
- template <bool APPROXIMATION_MODE>
+ template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_i0() {
#pragma GCC unroll 0

- for (int d = 0; d < 8; d++) {
+ for (int d = 0; d < ITERATIONS; d++) {
vFloat result = 0.0f;
vFloat input = dst_reg[0];
vFloat x = input * input;