From 59c489106528bc9e32ebddf42f76ce358a0c842b Mon Sep 17 00:00:00 2001 From: umadevimcw Date: Tue, 19 Nov 2024 06:21:21 +0000 Subject: [PATCH] #13625: Update lerp op params --- .../eltwise/ternary/ternary_composite.hpp | 8 ++-- .../eltwise/ternary/ternary_composite_op.cpp | 20 +++++----- .../eltwise/ternary/ternary_pybind.hpp | 40 +++++++++---------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite.hpp index d1f378bd425..d54754a35c6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite.hpp @@ -31,12 +31,12 @@ template struct ExecuteTernaryCompositeLerp { static Tensor invoke( - const Tensor& input_tensor_a, - const Tensor& input_tensor_b, - const Tensor& input_tensor_c, + const Tensor& input_tensor, + const Tensor& end, + const Tensor& weight, const std::optional& memory_config = std::nullopt) { - return OpHandler::handle(input_tensor_a, input_tensor_b, input_tensor_c, memory_config); + return OpHandler::handle(input_tensor, end, weight, memory_config); } static Tensor invoke( diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp index 211ec8c9f00..18bfc35c066 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp @@ -51,20 +51,20 @@ Tensor _addcdiv( } // lerp(input, end, weight) = start weight * (end - start) -Tensor _lerp_overload(const Tensor& input_a, const Tensor& input_b, float value, const std::optional& output_mem_config) { - TT_FATAL(input_a.dtype() == input_b.dtype(), "Expected the same dtype as start (input_a), for end (input_b)"); - Tensor t_diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config); - Tensor t_mul = ttnn::multiply(t_diff, value, std::nullopt, output_mem_config); - Tensor result = ttnn::add(input_a, t_mul, std::nullopt, output_mem_config); +Tensor _lerp_overload(const Tensor& input, const Tensor& end, float weight, const std::optional& output_mem_config) { + TT_FATAL(input.dtype() == end.dtype(), "Expected the same dtype as start (input), for end"); + Tensor t_diff = ttnn::subtract(end, input, std::nullopt, output_mem_config); + Tensor t_mul = ttnn::multiply(t_diff, weight, std::nullopt, output_mem_config); + Tensor result = ttnn::add(input, t_mul, std::nullopt, output_mem_config); return result; } -Tensor _lerp(const Tensor& input_a, const Tensor& input_b, const Tensor& input_c, const std::optional& output_mem_config) { - TT_FATAL(input_a.dtype() == input_b.dtype(), "Expected the same dtype as start (input_a), for end (input_b)"); - TT_FATAL(input_a.dtype() == input_c.dtype(), "Expected the same dtype as start (input_a), for weight (input_c)"); +Tensor _lerp(const Tensor& input, const Tensor& end, const Tensor& weight, const std::optional& output_mem_config) { + TT_FATAL(input.dtype() == end.dtype(), "Expected the same dtype as start (input), for end"); + TT_FATAL(input.dtype() == weight.dtype(), "Expected the same dtype as start (input), for weight"); Tensor t_diff = ttnn::multiply( - ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config), input_c, std::nullopt, output_mem_config); - Tensor result = ttnn::add(input_a, t_diff, std::nullopt, output_mem_config); + ttnn::subtract(end, input, std::nullopt, output_mem_config), weight, std::nullopt, output_mem_config); + Tensor result = ttnn::add(input, t_diff, std::nullopt, output_mem_config); return result; } diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp index a0d32aa672e..020328a4293 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp @@ -213,12 +213,12 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation, {2} .. math:: - \mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input\_tensor\_a, input\_tensor\_b, input\_tensor\_c}}) + \mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input, end, weight}}) Args: - input_tensor_a (ttnn.Tensor): the input tensor. - input_tensor_b (ttnn.Tensor): the input tensor. - input_tensor_c (ttnn.Tensor or Number): the input tensor. + input (ttnn.Tensor): the input tensor with the starting points. + end (ttnn.Tensor): the tensor with the ending points. + weight (ttnn.Tensor or float): the weight for the interpolation formula. Keyword Args: @@ -240,7 +240,7 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation, bfloat8_b/bfloat4_b supports only on TILE_LAYOUT - input_tensor_b, input_tensor_c should have same dtype as input_tensor_a + end, weight tensors should have same dtype as input Example: >>> tensor1 = ttnn.from_torch(torch.tensor([[1, 0], [1, 0]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device) @@ -258,29 +258,29 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation, doc, ttnn::pybind_overload_t{ [](const ternary_operation_t& self, - const Tensor& input_tensor_a, - const Tensor& input_tensor_b, - const Tensor& input_tensor_c, + const Tensor& input, + const Tensor& end, + const Tensor& weight, const std::optional& memory_config) { - return self(input_tensor_a, input_tensor_b, input_tensor_c, memory_config); + return self(input, end, weight, memory_config); }, - py::arg("input_tensor_a"), - py::arg("input_tensor_b"), - py::arg("input_tensor_c"), + py::arg("input"), + py::arg("end"), + py::arg("weight"), py::kw_only(), py::arg("memory_config") = std::nullopt}, ttnn::pybind_overload_t{ [](const ternary_operation_t& self, - const Tensor& input_tensor_a, - const Tensor& input_tensor_b, - float value, + const Tensor& input, + const Tensor& end, + float weight, const std::optional& memory_config) { - return self(input_tensor_a, input_tensor_b, value, memory_config); + return self(input, end, weight, memory_config); }, - py::arg("input_tensor_a"), - py::arg("input_tensor_b"), - py::arg("value"), + py::arg("input"), + py::arg("end"), + py::arg("weight"), py::kw_only(), py::arg("memory_config") = std::nullopt}); } @@ -383,7 +383,7 @@ void py_module(py::module& module) { detail::bind_ternary_lerp( module, ttnn::lerp, - R"doc(Computes Lerp on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc"); + R"doc(Computes Lerp on :attr:`input`, :attr:`end` and :attr:`weight` and returns the tensor with the same layout as :attr:`input`)doc"); detail::bind_ternary_mac( module,