#10382: Overload complex_recip with reciprocal
VirdhatchaniKN committed Jul 18, 2024
1 parent eae7837 commit 9e32d67
Showing 10 changed files with 69 additions and 22 deletions.
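For orientation, a minimal sketch of the user-facing effect (illustrative, not part of the diff): xtt is a ComplexTensor and memcfg a memory config, built as in the updated test_level2_recip below.

# before this commit:
#   tt_dev = ttnn.complex_recip(xtt, memory_config=memcfg)
# after this commit, the existing unary reciprocal op is overloaded for complex inputs:
tt_dev = ttnn.reciprocal(xtt, memory_config=memcfg)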
docs/source/ttnn/ttnn/api.rst (1 addition, 1 deletion)
@@ -265,7 +265,7 @@ Pointwise Unary
ttnn/angle_bw
ttnn/conj_bw
ttnn/conj
ttnn/complex_recip
ttnn/reciprocal
ttnn/polar

Pointwise Binary
docs/source/ttnn/ttnn/ttnn/complex_recip.rst (0 additions, 6 deletions)

This file was deleted.

@@ -579,7 +579,7 @@ def test_level2_recip(bs, memcfg, dtype, device, function_level_defaults):
ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
tt_dev = ttnn.complex_recip(xtt, memory_config=memcfg)
tt_dev = ttnn.reciprocal(xtt, memory_config=memcfg)
tt_dev_r = tt_dev.real.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_dev_i = tt_dev.imag.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_dev = Complex(re=tt_dev_r, im=tt_dev_i).metal
tt_eager/tt_dnn/op_library/complex/complex_ops.cpp (1 addition, 1 deletion)
@@ -172,7 +172,7 @@ ComplexTensor complex_mul(const ComplexTensor& ab, const ComplexTensor& cd, con
}

ComplexTensor complex_div(const ComplexTensor& input_a, const ComplexTensor& input_b, const MemoryConfig& output_mem_config) {
return complex_mul( input_a, ttnn::operations::complex_unary::_complex_recip( input_b , output_mem_config ), output_mem_config );
return complex_mul( input_a, ttnn::operations::complex_unary::_reciprocal( input_b , output_mem_config ), output_mem_config );
}

ComplexTensor complex_add(const ComplexTensor& input_a, const ComplexTensor& input_b, const MemoryConfig& output_mem_config) {
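In the unchanged complex_div above, division is reduced to a multiply by the reciprocal of the divisor; assuming the standard identity, this amounts to

\frac{a + ib}{c + id} = (a + ib)\cdot\frac{1}{c + id} = \frac{(a + ib)(c - id)}{c^{2} + d^{2}}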
@@ -53,7 +53,6 @@ constexpr auto is_real = ttnn::register_operation<operations::complex_unary::Exe

//OpHandler_complex_type2 = get_function_complex_unary_type2 --> ComplexTensor return type
constexpr auto conj = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::CONJ>>("ttnn::conj");
constexpr auto complex_recip = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::COMPLEX_RECIP>>("ttnn::complex_recip");
constexpr auto polar = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::POLAR>>("ttnn::polar");

}
@@ -129,11 +129,6 @@ void py_module(py::module& module) {
ttnn::conj,
R"doc(Returns complex conjugate value of complex tensor :attr:`input_tensor`.)doc");

detail::bind_complex_unary_type2(
module,
ttnn::complex_recip,
R"doc(Returns complex reciprocal value of complex tensor :attr:`input_tensor`.)doc");

detail::bind_complex_unary_type2(
module,
ttnn::polar,
@@ -44,7 +44,7 @@ ComplexTensor _conj(const ComplexTensor& input, const MemoryConfig& output_mem_c
return ComplexTensor({input[0], ttnn::neg(input[1],output_mem_config)});
}

ComplexTensor _complex_recip(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
ComplexTensor _reciprocal(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
Tensor a_plus_b = ttnn::add(input[0],input[1],std::nullopt,output_mem_config);
Tensor a_minus_b = ttnn::subtract(input[0],input[1],std::nullopt,output_mem_config);
Tensor asqr_plus_bsqr = ttnn::add(ttnn::square(input[0],output_mem_config),ttnn::square(input[1],output_mem_config),
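The renamed _reciprocal kernel visibly forms a^2 + b^2 above; presumably the rest of the function applies the standard identity

\frac{1}{a + ib} = \frac{a - ib}{a^{2} + b^{2}}, \qquad \operatorname{Re} = \frac{a}{a^{2} + b^{2}}, \quad \operatorname{Im} = \frac{-b}{a^{2} + b^{2}}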
@@ -21,7 +21,7 @@ enum class ComplexUnaryOpType {
IS_REAL,
ABS,
CONJ,
COMPLEX_RECIP,
RECIPROCAL,
POLAR,
};

@@ -35,7 +35,7 @@ Tensor _abs(const ComplexTensor& input, const MemoryConfig& output_mem_config);

//OpHandler_complex_type2 = get_function_complex_unary_type2 --> ComplexTensor return type
ComplexTensor _conj(const ComplexTensor& input, const MemoryConfig& output_mem_config);
ComplexTensor _complex_recip(const ComplexTensor& input, const MemoryConfig& output_mem_config);
ComplexTensor _reciprocal(const ComplexTensor& input, const MemoryConfig& output_mem_config);
ComplexTensor _polar(const ComplexTensor& input, const MemoryConfig& output_mem_config);

template <ComplexUnaryOpType OpType>
@@ -94,9 +94,9 @@ struct OpHandler_complex_type2<ComplexUnaryOpType::CONJ> {
};

template <>
struct OpHandler_complex_type2<ComplexUnaryOpType::COMPLEX_RECIP> {
struct OpHandler_complex_type2<ComplexUnaryOpType::RECIPROCAL> {
static ComplexTensor handle( const ComplexTensor& input, const MemoryConfig& output_mem_config ) {
return _complex_recip(input, output_mem_config);
return _reciprocal(input, output_mem_config);
}
};

@@ -116,7 +116,7 @@ std::vector<ComplexTensor> _complex_recip_bw(const ComplexTensor& grad, const Co
std::vector<ComplexTensor> grad_tensor;
Tensor condition_nan = ttnn::logical_and(ttnn::eqz(input.real(),output_mem_config), ttnn::eqz(input.imag(),output_mem_config), std::nullopt, output_mem_config);
ComplexTensor neg_grad = ComplexTensor({ttnn::neg(grad.real(),output_mem_config), ttnn::neg(grad.imag(),output_mem_config)});
ComplexTensor inp_recip = ttnn::operations::complex_unary::_complex_recip(input, output_mem_config);
ComplexTensor inp_recip = ttnn::operations::complex_unary::_reciprocal(input, output_mem_config);
ComplexTensor grad_inp = complex_mul(neg_grad, ttnn::operations::complex_unary::_conj(complex_mul(inp_recip, inp_recip, output_mem_config), output_mem_config), output_mem_config) ;
neg_grad.deallocate();
inp_recip.deallocate();
ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp (60 additions, 1 deletion)
@@ -131,6 +131,65 @@ void bind_unary_operation_overload_complex(py::module& module, const unary_opera
py::arg("memory_config")});
}

template <typename unary_operation_t>
void bind_unary_operation_overload_complex_return_complex(py::module& module, const unary_operation_t& operation, const std::string& info_doc = "" ) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor or ComplexTensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
{2}
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
* :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
info_doc);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const Tensor& input_tensor,
const std::optional<MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const uint8_t& queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = 0},

ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const ComplexTensor& input_tensor,
const ttnn::MemoryConfig& memory_config) -> ComplexTensor {
using ComplexUnaryOp = ttnn::operations::complex_unary::ExecuteComplexUnaryType2<complex_unary::ComplexUnaryOpType::RECIPROCAL>;
return ComplexUnaryOp::execute_on_main_thread(input_tensor, memory_config);
},
py::arg("input_tensor"),
py::kw_only(),
py::arg("memory_config")});
}

template <typename unary_operation_t>
void bind_unary_operation_with_fast_and_approximate_mode(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
@@ -996,7 +1055,7 @@ void py_module(py::module& module) {
detail::bind_unary_operation(module, ttnn::ltz);
detail::bind_unary_operation(module, ttnn::neg);
detail::bind_unary_operation(module, ttnn::nez);
detail::bind_unary_operation(module, ttnn::reciprocal);
detail::bind_unary_operation_overload_complex_return_complex(module, ttnn::reciprocal);
detail::bind_unary_operation(module, ttnn::relu);
detail::bind_unary_operation(module, ttnn::relu6);
detail::bind_unary_operation(module, ttnn::sigmoid);
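A hedged sketch of how the two pybind overloads registered in bind_unary_operation_overload_complex_return_complex above resolve from Python; device, xtt and memcfg are assumed to exist as in the docstring example and the updated test, so this is illustrative rather than runnable in isolation.

import torch
import ttnn

# Tensor overload: keyword arguments are optional (defaults bound above).
t = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
out = ttnn.reciprocal(t)

# ComplexTensor overload added by this commit: returns a ComplexTensor and requires an
# explicit memory_config, since py::arg("memory_config") is bound without a default.
out_c = ttnn.reciprocal(xtt, memory_config=memcfg)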
