diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py index ac1f2f1775f..831944b77a0 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py @@ -1053,3 +1053,40 @@ def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device): golden_tensor = golden_function(in_data1, scalar) assert_with_pcc(golden_tensor, output_tensor, 0.999) + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 2, 32, 64, 64])), + (torch.Size([1, 3, 7, 29, 127])), + (torch.Size([1, 3, 2, 32])), + (torch.Size([1, 6, 49, 97])), + (torch.Size([1, 7, 320])), + (torch.Size([1, 49, 321])), + (torch.Size([4, 32])), + (torch.Size([49, 321])), + ), +) +@pytest.mark.parametrize( + "weight", + [ + [-0.25], + [-2.7], + [0.45], + [6.4], + [2], + [-1], + ], +) +@skip_for_grayskull() +def test_binary_prelu_1D_weight(input_shapes, weight, device): + in_data1 = torch.rand(input_shapes, dtype=torch.bfloat16) * 200 - 100 + input_tensor1 = ttnn.from_torch(in_data1, layout=ttnn.TILE_LAYOUT, device=device) + + output_tensor = ttnn.prelu(input_tensor1, weight) + output_tensor = ttnn.to_torch(output_tensor) + golden_function = ttnn.get_golden_function(ttnn.prelu) + golden_tensor = golden_function(in_data1, weight) + + assert_with_pcc(golden_tensor, output_tensor, 0.999) diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_mul.py b/tests/ttnn/unit_tests/operations/eltwise/test_mul.py index a67d4c6f4c7..e82560cd941 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_mul.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_mul.py @@ -108,8 +108,9 @@ def test_multiply_with_scalar_sharded(device, scalar, batch_size, output_memory_ torch_input_tensor_a = torch.rand((batch_size, 16, 384, 384), dtype=torch.float32) torch_output_tensor = scalar * torch_input_tensor_a + # GS has smaller L1 than WH 
input_tensor_a = ttnn.from_torch( - torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, device=device + torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, device=device ) output = ttnn.mul(input_tensor_a, scalar, memory_config=output_memory_config) output = ttnn.to_torch(output) diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp index 6481f71d227..e6d1c061ed7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp @@ -250,6 +250,11 @@ struct ExecutePrelu const Tensor& input_tensor_b, const std::optional<MemoryConfig>& memory_config = std::nullopt); + static Tensor invoke( + const Tensor& input_tensor, + const std::array<float, 1>& weight, + const std::optional<MemoryConfig>& memory_config = std::nullopt); + static Tensor invoke( const Tensor& input_tensor, float scalar, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp index 8be9e9ef579..4fcf084c781 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp @@ -497,6 +497,104 @@ void bind_binary_composite_overload( py::arg("memory_config") = std::nullopt}); } +template <typename binary_operation_t> +void bind_prelu( py::module& module, const binary_operation_t& operation, const std::string& description, const std::string& supported_dtype = "BFLOAT16", const std::string& supported_rank = "2, 3, 4", const std::string& example_tensor1 = "ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)", const std::string& example_tensor2 = "ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)", const std::string& note="") { auto doc = fmt::format( R"doc( {2} + + .. 
math:: \mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input\_tensor\_a,input\_tensor\_b}}) + + Args: + input_tensor_a (ttnn.Tensor): the input tensor. + input_tensor_b (ttnn.Tensor or List[float] of length 1 or Number): weight. + + Keyword Args: + memory_config (ttnn.MemoryConfig, optional): memory configuration for the operation. Defaults to `None`. + + Returns: + ttnn.Tensor: the output tensor. + + Note: + Supported dtypes, layouts, and ranks: + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - {3} + - TILE + - {4} + + {7} + + Example: + >>> tensor1 = {5} + >>> tensor2 = {6} + >>> output = {1}(tensor1, tensor2/scalar) + )doc", + operation.base_name(), + operation.python_fully_qualified_name(), + description, + supported_dtype, + supported_rank, + example_tensor1, + example_tensor2, + note); + + bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor_a, + const Tensor& input_tensor_b, + const std::optional<MemoryConfig>& memory_config) { + return self(input_tensor_a, input_tensor_b, memory_config); + }, + py::arg("input_tensor_a"), + py::arg("weight"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}, + + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor_a, + float value, + const std::optional<MemoryConfig>& memory_config) { + return self(input_tensor_a, value, memory_config); + }, + py::arg("input_tensor_a"), + py::arg("weight"), + py::kw_only(), + py::arg("memory_config") = std::nullopt}, + + ttnn::pybind_overload_t{ + [](const binary_operation_t& self, + const Tensor& input_tensor_a, + const std::array<float, 1> &weight, + const std::optional<MemoryConfig>& memory_config) { + return self(input_tensor_a, weight, memory_config); + }, + py::arg("input_tensor_a"), + py::arg("weight"), + py::kw_only(), + py::arg("memory_config") = std::nullopt} + ); +} + template <typename binary_operation_t> void bind_div(py::module& module, const binary_operation_t& operation, const 
std::string& description, const std::string& math) { auto doc = fmt::format( @@ -1182,14 +1280,15 @@ void py_module(py::module& module) { R"doc(Computes maximum for :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc", R"doc(BFLOAT16, BFLOAT8_B)doc"); - detail::bind_binary_composite_overload( + detail::bind_prelu( module, ttnn::prelu, - R"doc(Perform an eltwise-prelu operation. PReLU supports the case where the size of input_tensor_b matches the number of channels in input_tensor_a.)doc", + R"doc(Perform an eltwise-prelu operation.)doc", R"doc(BFLOAT16, BFLOAT8_B)doc", R"doc(2, 3, 4, 5)doc", R"doc(ttnn.from_torch(torch.rand([1, 2, 32, 32], dtype=torch.bfloat16), device=device))doc", - R"doc(ttnn.from_torch(torch.tensor([1, 2], dtype=torch.bfloat16), device=device))doc"); + R"doc(ttnn.from_torch(torch.tensor([1, 2], dtype=torch.bfloat16), device=device))doc", + R"doc(PReLU supports the case where weight is a scalar or 1D list/array of size=1 or a 1D tensor :attr:`input_tensor_b` of size = the second dimension in :attr:`input_tensor_a`)doc"); detail::bind_binary_composite( module, diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index f0cee3b7eb8..51d65dac416 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -271,8 +271,13 @@ Tensor _div_no_nan(const Tensor& input_a, const Tensor& input_b, const std::opti return ttnn::where(ttnn::eqz(input_b, output_mem_config), 0, div_result); } -Tensor ExecutePrelu::invoke(const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config) { - return ttnn::prelu_sfpu(input, scalar); +Tensor ExecutePrelu::invoke(const Tensor& input, float weight, const std::optional<MemoryConfig>& output_mem_config) { + return ttnn::prelu_sfpu(input, weight); +} + 
+Tensor ExecutePrelu::invoke(const Tensor& input, const std::array<float, 1>& weight, const std::optional<MemoryConfig>& output_mem_config) { + float scalar_weight = weight[0]; + return ttnn::prelu_sfpu(input, scalar_weight); } Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) { @@ -286,6 +291,7 @@ Tensor ExecutePrelu::invoke(const Tensor& input_a, const Tensor& input_b, const reshape[1] = s_a[1]; b = ttnn::reshape(input_b, ttnn::Shape(reshape)); } + Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a); return result; }