Skip to content

Commit

Permalink
#14933: Support PRelu for single element weight array (#15261)
Browse files Browse the repository at this point in the history
### Ticket
Link to Github Issue #14933

### Problem description
PRelu needs to support a single-element weight array as input.

### What's changed
Added support for passing the PReLU weight as a single-element array (list of length 1), routed to the scalar SFPU implementation.

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/11960853362
https://github.com/tenstorrent/tt-metal/actions/runs/11987294776
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
KalaivaniMCW authored Nov 23, 2024
1 parent 5ad4e34 commit 6f491aa
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 6 deletions.
37 changes: 37 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,3 +1053,40 @@ def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
golden_tensor = golden_function(in_data1, scalar)

assert_with_pcc(golden_tensor, output_tensor, 0.999)


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 2, 32, 64, 64])),
        (torch.Size([1, 3, 7, 29, 127])),
        (torch.Size([1, 3, 2, 32])),
        (torch.Size([1, 6, 49, 97])),
        (torch.Size([1, 7, 320])),
        (torch.Size([1, 49, 321])),
        (torch.Size([4, 32])),
        (torch.Size([49, 321])),
    ),
)
@pytest.mark.parametrize(
    "weight",
    [
        [-0.25],
        [-2.7],
        [0.45],
        [6.4],
        [2],
        [-1],
    ],
)
@skip_for_grayskull()
def test_binary_prelu_1D_weight(input_shapes, weight, device):
    """PReLU with the weight given as a single-element list, checked against the golden function."""
    # Inputs uniformly distributed in [-100, 100) so both branches of PReLU are exercised.
    torch_input = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100
    tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

    tt_result = ttnn.to_torch(ttnn.prelu(tt_input, weight))

    golden_fn = ttnn.get_golden_function(ttnn.prelu)
    golden = golden_fn(torch_input, weight)

    assert_with_pcc(golden, tt_result, 0.999)
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/eltwise/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def test_multiply_with_scalar_sharded(device, scalar, batch_size, output_memory_
torch_input_tensor_a = torch.rand((batch_size, 16, 384, 384), dtype=torch.float32)
torch_output_tensor = scalar * torch_input_tensor_a

# GS has smaller L1 than WH
input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, device=device
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, device=device
)
output = ttnn.mul(input_tensor_a, scalar, memory_config=output_memory_config)
output = ttnn.to_torch(output)
Expand Down
5 changes: 5 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ struct ExecutePrelu
const Tensor& input_tensor_b,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
const std::array<float, 1>& weight,
const std::optional<MemoryConfig>& memory_config = std::nullopt);

static Tensor invoke(
const Tensor& input_tensor,
float scalar,
Expand Down
105 changes: 102 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,104 @@ void bind_binary_composite_overload(
py::arg("memory_config") = std::nullopt});
}

// Registers the Python binding for ttnn.prelu with three call forms:
//   1. (Tensor, Tensor)               - per-channel weight tensor
//   2. (Tensor, float)                - scalar weight
//   3. (Tensor, std::array<float, 1>) - single-element weight list
// The generated docstring is assembled from the format arguments below.
// NOTE(review): pybind11 tries overloads in registration order; the float
// overload is registered before the std::array one, so a Python list of
// length 1 is expected to fail float conversion and fall through to the
// array overload — confirm this dispatch on all supported call sites.
template <typename binary_operation_t>
void bind_prelu(
    py::module& module,
    const binary_operation_t& operation,
    const std::string& description,
    const std::string& supported_dtype = "BFLOAT16",
    const std::string& supported_rank = "2, 3, 4",
    const std::string& example_tensor1 = "ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)",
    const std::string& example_tensor2 = "ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)",
    const std::string& note="") {
    // {0}=base name, {1}=qualified name, {2}=description, {3}=dtypes,
    // {4}=ranks, {5}/{6}=example tensors, {7}=extra note.
auto doc = fmt::format(
R"doc(
{2}
.. math::
\mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input\_tensor\_a,input\_tensor\_b}})
Args:
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor or List[float] of length 1 or Number): weight.
Keyword Args:
memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`.
Returns:
ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- TILE
- {4}
{7}
Example:
>>> tensor1 = {5}
>>> tensor2 = {6}
>>> output = {1}(tensor1, tensor2/scalar)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description,
supported_dtype,
supported_rank,
example_tensor1,
example_tensor2,
note);

bind_registered_operation(
module,
operation,
doc,
// Overload 1: weight as a Tensor (per-channel PReLU).
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, input_tensor_b, memory_config);
},
py::arg("input_tensor_a"),
py::arg("weight"),
py::kw_only(),
py::arg("memory_config") = std::nullopt},

// Overload 2: weight as a plain Python number (scalar PReLU).
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor_a,
float value,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, value, memory_config);
},
py::arg("input_tensor_a"),
py::arg("weight"),
py::kw_only(),
py::arg("memory_config") = std::nullopt},

// Overload 3: weight as a single-element list/array (issue #14933).
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const Tensor& input_tensor_a,
const std::array<float, 1> &weight,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, weight, memory_config);
},
py::arg("input_tensor_a"),
py::arg("weight"),
py::kw_only(),
py::arg("memory_config") = std::nullopt}
);
}

template <typename binary_operation_t>
void bind_div(py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& math) {
auto doc = fmt::format(
Expand Down Expand Up @@ -1182,14 +1280,15 @@ void py_module(py::module& module) {
R"doc(Computes maximum for :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_binary_composite_overload(
detail::bind_prelu(
module,
ttnn::prelu,
R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc",
R"doc(Perform an eltwise-prelu operation.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc",
R"doc(2, 3, 4, 5)doc",
R"doc(ttnn.from_torch(torch.rand([1, 2, 32, 32], dtype=torch.bfloat16), device=device))doc",
R"doc(ttnn.from_torch(torch.tensor([1, 2], dtype=torch.bfloat16), device=device))doc");
R"doc(ttnn.from_torch(torch.tensor([1, 2], dtype=torch.bfloat16), device=device))doc",
R"doc(PReLU supports the case where weight is a scalar or 1D list/array of size=1 or a 1D tensor :attr:`input_tensor_b` of size = the second dimension in :attr:`input_tensor_a`)doc");

detail::bind_binary_composite(
module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,13 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti
return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result);
}

Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) {
return ttnn::prelu_sfpu(input, scalar);
// Scalar-weight PReLU: delegates directly to the SFPU primitive.
// NOTE(review): output_mem_config is accepted but not forwarded to
// prelu_sfpu — confirm whether the primitive should honor it.
Tensor ExecutePrelu::invoke(const Tensor& input, float weight, const std::optional<MemoryConfig>& output_mem_config) {
return ttnn::prelu_sfpu(input, weight);
}

// Single-element-array weight PReLU: unwrap the lone value and reuse
// the scalar SFPU path. output_mem_config is not forwarded, matching
// the scalar overload above.
Tensor ExecutePrelu::invoke(const Tensor& input, const std::array<float, 1>& weight, const std::optional<MemoryConfig>& output_mem_config) {
return ttnn::prelu_sfpu(input, weight.front());
}

Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Expand All @@ -286,6 +291,7 @@ Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const
reshape[1] = s_a[1];
b = ttnn::reshape(input_b, ttnn::Shape(reshape));
}

Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
return result;
}
Expand Down

0 comments on commit 6f491aa

Please sign in to comment.