Add forward support for PReLU (#14940)
### Ticket
Link to GitHub issue: #8544

### Problem description
Currently, PReLU is aliased to LeakyReLU, which is incorrect. It needs to be implemented properly as an eltwise operation.

### What's changed
- Updated the PReLU implementation in the composite structure.
- Added SFPU support (in this PR) for certain cases; the remaining cases will be handled in #14933.

### Additional Information
In Torch's PReLU, the second input (the weight tensor) has only two valid shapes: either a single-element tensor (the default) or a tensor whose size equals the number of input channels. Currently, this implementation supports only the case where the weight size matches the number of channels. Supporting a single-value tensor requires additional handling in the low-level kernel.
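A short illustration of the weight-shape rule described above (a sketch using standard PyTorch, not code from this PR):

```python
import torch

x = torch.randn(1, 3, 32, 32)  # N, C, H, W with 3 channels

# Per-channel slopes: weight size equals the channel count (supported here).
w_per_channel = torch.tensor([0.1, 0.2, 0.3])
y = torch.nn.functional.prelu(x, w_per_channel)

# Single-element weight (PyTorch's default): not yet supported by this
# implementation; it needs extra handling in the low-level kernel.
w_single = torch.tensor([0.25])
y_single = torch.nn.functional.prelu(x, w_single)
```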

### Checklist
- [ ] [All Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/11796920085)
mouliraj-mcw authored Nov 21, 2024
1 parent 019a5cc commit 0ab59b7
Showing 25 changed files with 376 additions and 27 deletions.
18 changes: 3 additions & 15 deletions tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
@@ -6,19 +6,13 @@
from functools import partial

import torch
import random
import ttnn

from tests.sweep_framework.sweep_utils.utils import gen_shapes
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30

random.seed(0)


# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
@@ -45,12 +39,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
return False, None


def torch_prelu(x, *args, **kwargs):
weight = kwargs.pop("scalar")
result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
return result


# This is the run instructions for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
@@ -65,14 +53,14 @@ def run(
*,
device,
) -> list:
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)
torch.manual_seed(0)

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_shape)

torch_output_tensor = torch_prelu(torch_input_tensor_a, scalar=weight)
golden_function = ttnn.get_golden_function(ttnn.prelu)
torch_output_tensor = golden_function(torch_input_tensor_a, weight)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
@@ -871,6 +871,7 @@ def test_run_eltwise_leaky_relu_op(
)

@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@skip_for_grayskull()
def test_run_eltwise_prelu(
self,
input_shapes,
@@ -307,6 +307,7 @@ def test_scalarB_leaky_relu(device, h, w, scalar):
run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu)


@skip_for_grayskull()
@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
60 changes: 60 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -12,6 +12,7 @@
compare_pcc,
compare_equal,
)
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import is_grayskull, skip_for_grayskull


@@ -993,3 +994,62 @@ def test_binary_lcm_ttnn(input_shapes, device):

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 2, 32, 64, 64])),
(torch.Size([1, 3, 7, 29, 127])),
(torch.Size([1, 3, 2, 32])),
(torch.Size([1, 6, 49, 97])),
(torch.Size([1, 7, 320])),
(torch.Size([1, 49, 321])),
(torch.Size([4, 32])),
(torch.Size([49, 321])),
),
)
def test_binary_prelu_ttnn(input_shapes, device):
in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
channels = input_shapes[1]
in_data2 = torch.rand((channels,), dtype=torch.bfloat16) * 200 - 100

input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor2 = ttnn.from_torch(in_data2, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.prelu(input_tensor1, input_tensor2)
output_tensor = ttnn.to_torch(output_tensor)
golden_function = ttnn.get_golden_function(ttnn.prelu)
golden_tensor = golden_function(in_data1, in_data2)

assert_with_pcc(golden_tensor, output_tensor, 0.999)


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 2, 32, 64, 64])),
(torch.Size([1, 3, 7, 29, 127])),
(torch.Size([1, 3, 2, 32])),
(torch.Size([1, 6, 49, 97])),
(torch.Size([1, 7, 320])),
(torch.Size([1, 49, 321])),
(torch.Size([4, 32])),
(torch.Size([49, 321])),
),
)
@pytest.mark.parametrize(
"scalar",
{-0.25, -2.7, 0.45, 6.4},
)
@skip_for_grayskull()
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn.prelu(input_tensor1, scalar)
output_tensor = ttnn.to_torch(output_tensor)
golden_function = ttnn.get_golden_function(ttnn.prelu)
golden_tensor = golden_function(in_data1, scalar)

assert_with_pcc(golden_tensor, output_tensor, 0.999)
@@ -25,3 +25,4 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_prelu(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat init = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
a = a * init;
}
v_endif;
dst_reg[0] = a;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
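For clarity, a scalar C++ sketch of the per-element math the vectorized loop above applies to each value in the DST register; the uint-to-float reinterpretation mirrors the Converter union. This is illustrative only and is not part of the diff.

```cpp
#include <cstdint>
#include <cstring>

// Illustrative scalar equivalent of calculate_prelu: negative inputs are
// scaled by the slope, non-negative inputs pass through unchanged.
static float prelu_reference(float a, uint32_t slope_bits) {
    float slope;
    std::memcpy(&slope, &slope_bits, sizeof(slope));  // reinterpret raw bits as fp32
    return (a < 0.0f) ? a * slope : a;
}
```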
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
@@ -87,5 +87,6 @@ enum SfpuType {
ceil,
unused,
cumsum,
fill
fill,
prelu,
};
@@ -36,3 +36,4 @@
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {


template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_prelu(uint value) {
// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat init = c_value.f;

#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
a = a * init;
}
v_endif;
dst_reg[0] = a;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
@@ -14,6 +14,7 @@ enum SfpuType {
reciprocal,
sqrt,
lrelu,
prelu,
power,
square,
tanh_derivative,
43 changes: 43 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_prelu.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
 * Performs an element-wise PReLU operation on the tile. The slope value is provided as the constant param0. The DST register buffer must be in an
 * acquired state via the *acquire_dst* call. This call is blocking and is only
 * available on the compute engine.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
 * | param0          | Constant slope value by which the input is multiplied when the input is less than 0 | uint32_t |                                                        | True     |
*/
ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
}

/**
* Please refer to documentation for any_init.
*/
ALWI void prelu_tile_init() { MATH((llk_math_eltwise_unary_sfpu_prelu_init<APPROX>())); }


} // namespace ckernel
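A minimal, hypothetical fragment showing how these calls are meant to be used from a compute kernel. It assumes a tile has already been copied into DST index 0 and the DST registers are acquired, as the documentation above requires; `slope_bits` is an assumed name for the fp32 slope reinterpreted as a uint32_t and is not defined in this PR.

```cpp
// Hypothetical usage sketch (not part of this diff).
// Assumes the DST register at index 0 already holds the input tile and the
// DST buffer is in the acquired state.
prelu_tile_init();                              // one-time SFPU init for PReLU
prelu_tile(/*idst=*/0, /*param0=*/slope_bits);  // slope passed as raw fp32 bits
```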
@@ -116,6 +116,10 @@
#include "compute_kernel_api/eltwise_unary/softplus.h"
#endif

#if SFPU_OP_PRELU_INCLUDE
#include "compute_kernel_api/eltwise_unary/prelu.h"
#endif

#if SFPU_OP_DROPOUT_INCLUDE
#include "compute_kernel_api/eltwise_unary/dropout.h"
#endif
17 changes: 17 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -243,6 +243,20 @@ struct ExecuteMinimum

};

struct ExecutePrelu
{
static Tensor invoke(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
float scalar,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

};

} // namespace binary
} // namespace operations

@@ -306,5 +320,8 @@ constexpr auto gcd = ttnn::register_operation_with_auto_launch_op<
constexpr auto lcm = ttnn::register_operation_with_auto_launch_op<
"ttnn::lcm",
operations::binary::ExecuteLCM>();
constexpr auto prelu = ttnn::register_operation_with_auto_launch_op<
"ttnn::prelu",
operations::binary::ExecutePrelu>();

} // namespace ttnn
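Based on the two `ExecutePrelu::invoke` overloads declared above, the registered operation can be called with either a per-channel weight tensor or a scalar slope. An illustrative sketch follows; `input` and `weights` are assumed to be device tensors created elsewhere, and this is not code from the PR.

```cpp
// Illustrative host-side calls (not part of this diff).
auto out_per_channel = ttnn::prelu(input, weights);  // weights sized to the channel count
auto out_scalar      = ttnn::prelu(input, 0.25f);    // single scalar slope
```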