Skip to content

Commit

Permalink
#13625: Update lerp op params
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Nov 19, 2024
1 parent b4066ac commit 59c4891
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ template <TernaryCompositeOpType ternary_comp_op_type>
struct ExecuteTernaryCompositeLerp
{
static Tensor invoke(
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const Tensor& input_tensor_c,
const Tensor& input_tensor,
const Tensor& end,
const Tensor& weight,
const std::optional<MemoryConfig>& memory_config = std::nullopt)
{
return OpHandler<ternary_comp_op_type>::handle(input_tensor_a, input_tensor_b, input_tensor_c, memory_config);
return OpHandler<ternary_comp_op_type>::handle(input_tensor, end, weight, memory_config);
}

static Tensor invoke(
Expand Down
20 changes: 10 additions & 10 deletions ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ Tensor _addcdiv(
}

// lerp(input, end, weight) = start weight * (end - start)
Tensor _lerp_overload(const Tensor& input_a, const Tensor& input_b, float value, const std::optional<MemoryConfig>& output_mem_config) {
TT_FATAL(input_a.dtype() == input_b.dtype(), "Expected the same dtype as start (input_a), for end (input_b)");
Tensor t_diff = ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config);
Tensor t_mul = ttnn::multiply(t_diff, value, std::nullopt, output_mem_config);
Tensor result = ttnn::add(input_a, t_mul, std::nullopt, output_mem_config);
Tensor _lerp_overload(const Tensor& input, const Tensor& end, float weight, const std::optional<MemoryConfig>& output_mem_config) {
TT_FATAL(input.dtype() == end.dtype(), "Expected the same dtype as start (input), for end");
Tensor t_diff = ttnn::subtract(end, input, std::nullopt, output_mem_config);
Tensor t_mul = ttnn::multiply(t_diff, weight, std::nullopt, output_mem_config);
Tensor result = ttnn::add(input, t_mul, std::nullopt, output_mem_config);
return result;
}

Tensor _lerp(const Tensor& input_a, const Tensor& input_b, const Tensor& input_c, const std::optional<MemoryConfig>& output_mem_config) {
TT_FATAL(input_a.dtype() == input_b.dtype(), "Expected the same dtype as start (input_a), for end (input_b)");
TT_FATAL(input_a.dtype() == input_c.dtype(), "Expected the same dtype as start (input_a), for weight (input_c)");
Tensor _lerp(const Tensor& input, const Tensor& end, const Tensor& weight, const std::optional<MemoryConfig>& output_mem_config) {
TT_FATAL(input.dtype() == end.dtype(), "Expected the same dtype as start (input), for end");
TT_FATAL(input.dtype() == weight.dtype(), "Expected the same dtype as start (input), for weight");
Tensor t_diff = ttnn::multiply(
ttnn::subtract(input_b, input_a, std::nullopt, output_mem_config), input_c, std::nullopt, output_mem_config);
Tensor result = ttnn::add(input_a, t_diff, std::nullopt, output_mem_config);
ttnn::subtract(end, input, std::nullopt, output_mem_config), weight, std::nullopt, output_mem_config);
Tensor result = ttnn::add(input, t_diff, std::nullopt, output_mem_config);
return result;
}

Expand Down
40 changes: 20 additions & 20 deletions ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation,
{2}
.. math::
\mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input\_tensor\_a, input\_tensor\_b, input\_tensor\_c}})
\mathrm{{output\_tensor}} = \verb|{0}|(\mathrm{{input, end, weight}})
Args:
input_tensor_a (ttnn.Tensor): the input tensor.
input_tensor_b (ttnn.Tensor): the input tensor.
input_tensor_c (ttnn.Tensor or Number): the input tensor.
input (ttnn.Tensor): the input tensor with the starting points.
end (ttnn.Tensor): the tensor with the ending points.
weight (ttnn.Tensor or float): the weight for the interpolation formula.
Keyword Args:
Expand All @@ -240,7 +240,7 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation,
bfloat8_b/bfloat4_b supports only on TILE_LAYOUT
input_tensor_b, input_tensor_c should have same dtype as input_tensor_a
end, weight tensors should have same dtype as input
Example:
>>> tensor1 = ttnn.from_torch(torch.tensor([[1, 0], [1, 0]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
Expand All @@ -258,29 +258,29 @@ void bind_ternary_lerp(py::module& module, const ternary_operation_t& operation,
doc,
ttnn::pybind_overload_t{
[](const ternary_operation_t& self,
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
const Tensor& input_tensor_c,
const Tensor& input,
const Tensor& end,
const Tensor& weight,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, input_tensor_b, input_tensor_c, memory_config);
return self(input, end, weight, memory_config);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg("input_tensor_c"),
py::arg("input"),
py::arg("end"),
py::arg("weight"),
py::kw_only(),
py::arg("memory_config") = std::nullopt},

ttnn::pybind_overload_t{
[](const ternary_operation_t& self,
const Tensor& input_tensor_a,
const Tensor& input_tensor_b,
float value,
const Tensor& input,
const Tensor& end,
float weight,
const std::optional<MemoryConfig>& memory_config) {
return self(input_tensor_a, input_tensor_b, value, memory_config);
return self(input, end, weight, memory_config);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg("value"),
py::arg("input"),
py::arg("end"),
py::arg("weight"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}
Expand Down Expand Up @@ -383,7 +383,7 @@ void py_module(py::module& module) {
detail::bind_ternary_lerp(
module,
ttnn::lerp,
R"doc(Computes Lerp on :attr:`input_tensor_a`, :attr:`input_tensor_b` and :attr:`input_tensor_c` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc");
R"doc(Computes Lerp on :attr:`input`, :attr:`end` and :attr:`weight` and returns the tensor with the same layout as :attr:`input`)doc");

detail::bind_ternary_mac(
module,
Expand Down

0 comments on commit 59c4891

Please sign in to comment.