Skip to content

Commit

Permalink
#13779: Optimize pow (#15534)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue #13779 


### What's changed

- optimised the existing pow implementation with binary exponentiation.
  • Loading branch information
mouliraj-mcw authored Dec 3, 2024
1 parent 3d9149d commit 6e2373c
Showing 4 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_unary.py
Original file line number Diff line number Diff line change
@@ -337,7 +337,7 @@ def test_logit(device, h, w, scalar):
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
def test_pow(device, h, w, scalar):
run_unary_test_with_float(device, h, w, scalar, ttnn.pow, pcc=0.9)
run_unary_test_with_float(device, h, w, scalar, ttnn.pow, pcc=0.999)


@pytest.mark.parametrize("lower_limit", [0, 1.0, 2])
Original file line number Diff line number Diff line change
@@ -16,10 +16,15 @@ template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_power_iterative(const uint exponent) {
#pragma GCC unroll 8
for (int d = 0; d < 8; d++) {
uint exp = exponent;
vFloat in = dst_reg[0];
vFloat result = 1.0f;
for (uint i = 0; i < exponent; i++) {
result *= in;
while (exp > 0) {
if (exp & 1){
result *= in;
}
in *= in;
exp >>= 1;
}
dst_reg[0] = result;
dst_reg++;
Original file line number Diff line number Diff line change
@@ -17,10 +17,15 @@ template <bool APPROXIMATION_MODE, int ITERATIONS = 4>
inline void calculate_power_iterative(const uint exponent) {
#pragma GCC unroll 4
for (int d = 0; d < ITERATIONS; d++) {
uint exp = exponent;
vFloat in = dst_reg[0];
vFloat result = 1.0f;
for (uint i = 0; i < exponent; i++) {
result *= in;
while (exp > 0) {
if (exp & 1){
result *= in;
}
in *= in;
exp >>= 1;
}
dst_reg[0] = result;
dst_reg++;
Original file line number Diff line number Diff line change
@@ -17,10 +17,15 @@ template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_power_iterative(const uint exponent) {
#pragma GCC unroll 8
for (int d = 0; d < 8; d++) {
uint exp = exponent;
vFloat in = dst_reg[0];
vFloat result = 1.0f;
for (uint i = 0; i < exponent; i++) {
result *= in;
while (exp > 0) {
if (exp & 1){
result *= in;
}
in *= in;
exp >>= 1;
}
dst_reg[0] = result;
dst_reg++;

0 comments on commit 6e2373c

Please sign in to comment.