Skip to content

Commit

Permalink
#10382: Migrate polar to TTNN
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jul 18, 2024
1 parent 51aa633 commit eae7837
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ Pointwise Unary
ttnn/conj_bw
ttnn/conj
ttnn/complex_recip
ttnn/polar

Pointwise Binary
================
Expand Down
4 changes: 0 additions & 4 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -551,10 +551,6 @@ Complex arithmetic can be carried out for multiply, divide, add and subtract as

.. autofunction:: tt_lib.tensor.complex_div

and then unary operations for,

.. autofunction:: tt_lib.tensor.polar

Complex Operations (Type 2)
===========================
Type 2 Complex representation allows for more flexible storage than earlier one while providing same set of
Expand Down
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/polar.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _ttnn.polar:

ttnn.polar
###############

.. autofunction:: ttnn.polar
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def test_level2_polar(bs, memcfg, dtype, device, function_level_defaults):
ttl.tensor.Tensor(x.real, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
ttl.tensor.Tensor(x.imag, dtype).to(ttl.tensor.Layout.TILE).to(device, memcfg),
)
tt_dev = ttl.tensor.polar(xtt, memcfg)
tt_dev = ttnn.polar(xtt, memory_config=memcfg)
tt_dev_real = tt_dev.real.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_dev_imag = tt_dev.imag.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
tt_cpu = torch.polar(x.real, x.imag)
Expand Down
14 changes: 0 additions & 14 deletions tt_eager/tt_dnn/op_library/complex/complex_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,6 @@ Tensor polar(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
return mk_complex( r, i, output_mem_config);
}

// Polar-to-Cartesian transform: interprets input.real() as magnitude r and
// input.imag() as angle theta, returning the complex tensor
// (r*cos(theta), r*sin(theta)).
// NOTE(review): this legacy tt_metal overload is removed by this commit in
// favor of the ttnn::operations::complex_unary::_polar implementation.
ComplexTensor polar(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
    const Tensor& input_a = input.real();
    const Tensor& input_b = input.imag();
    Tensor c = ttnn::cos(input_b,output_mem_config);
    Tensor r = ttnn::multiply(input_a,c,std::nullopt,output_mem_config);
    c.deallocate();  // free the cos intermediate before computing sin

    Tensor s = ttnn::sin(input_b,output_mem_config);
    Tensor i = ttnn::multiply(input_a,s,std::nullopt,output_mem_config);
    s.deallocate();  // free the sin intermediate before returning

    return ComplexTensor({r,i});
}

// backward ops for type2 complex tensor

// complex add
Expand Down
1 change: 0 additions & 1 deletion tt_eager/tt_dnn/op_library/complex/complex_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ ComplexTensor complex_div(const ComplexTensor& input_a, const ComplexTensor& inp

//polar operator: return a complex value tensor
Tensor polar(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
ComplexTensor polar(const ComplexTensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

//backward
std::vector<ComplexTensor> complex_add_bw(const ComplexTensor& grad, const ComplexTensor& input, const ComplexTensor& other, float alpha = 1.0, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
Expand Down
4 changes: 0 additions & 4 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,10 +1521,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
py::arg("input_a"), py::arg("input_b"),
py::arg("output_mem_config").noconvert() = std::nullopt,R"doc(Perform an eltwise-binary subtraction ``input_a - input_b`` on two complex tensors.)doc");

m_tensor.def("polar", py::overload_cast<const ComplexTensor&, const MemoryConfig&>(&tt::tt_metal::polar),
py::arg("input_a"),
py::arg("output_mem_config").noconvert() = std::nullopt,R"doc(Perform an polar to Cartesian transformation of the input.real(r), input.imag(theta) into x + i*y generating a type-2 complex tensor.)doc");

m_tensor.def("polar", py::overload_cast<const Tensor&,const Tensor&, const MemoryConfig&>(&tt::tt_metal::polar),
py::arg("input_a"), py::arg("input_b"),
py::arg("output_mem_config").noconvert() = std::nullopt,R"doc(Perform an polar to Cartesian transformation of the input_a (r), input_b(theta) into x + i*y generating a type-2 complex tensor.)doc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,6 @@ constexpr auto is_real = ttnn::register_operation<operations::complex_unary::Exe
//OpHandler_complex_type2 = get_function_complex_unary_type2 --> ComplexTensor return type
constexpr auto conj = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::CONJ>>("ttnn::conj");
constexpr auto complex_recip = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::COMPLEX_RECIP>>("ttnn::complex_recip");
constexpr auto polar = ttnn::register_operation<operations::complex_unary::ExecuteComplexUnaryType2<operations::complex_unary::ComplexUnaryOpType::POLAR>>("ttnn::polar");

}
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ void py_module(py::module& module) {
ttnn::complex_recip,
R"doc(Returns complex reciprocal value of complex tensor :attr:`input_tensor`.)doc");

detail::bind_complex_unary_type2(
module,
ttnn::polar,
R"doc(Perform an polar to Cartesian transformation on :attr:`input_tensor`, input_tensor.real(r), input_tensor.imag(theta) into x + i*y generating a complex tensor.)doc");

}

} // namespace complex_unary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,18 @@ ComplexTensor _complex_recip(const ComplexTensor& input, const MemoryConfig& out
return ComplexTensor({ conj_re, conj_im});
}

// Polar-to-Cartesian transform: treats input.real() as magnitude r and
// input.imag() as angle theta, producing the complex tensor
// (r*cos(theta), r*sin(theta)) allocated per output_mem_config.
ComplexTensor _polar(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
    const Tensor& magnitude = input.real();
    const Tensor& angle = input.imag();

    Tensor cos_theta = ttnn::cos(angle, output_mem_config);
    Tensor real_part = ttnn::multiply(magnitude, cos_theta, std::nullopt, output_mem_config);
    cos_theta.deallocate();  // release the cos intermediate before allocating sin

    Tensor sin_theta = ttnn::sin(angle, output_mem_config);
    Tensor imag_part = ttnn::multiply(magnitude, sin_theta, std::nullopt, output_mem_config);
    sin_theta.deallocate();  // release the sin intermediate before returning

    return ComplexTensor({real_part, imag_part});
}

} // namespace ttnn::operations::complex_unary
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ enum class ComplexUnaryOpType {
ABS,
CONJ,
COMPLEX_RECIP,
POLAR,
};

//OpHandler_complex_type1 = get_function_complex_unary --> Tensor return type
Expand All @@ -35,6 +36,7 @@ Tensor _abs(const ComplexTensor& input, const MemoryConfig& output_mem_config);
//OpHandler_complex_type2 = get_function_complex_unary_type2 --> ComplexTensor return type
ComplexTensor _conj(const ComplexTensor& input, const MemoryConfig& output_mem_config);
ComplexTensor _complex_recip(const ComplexTensor& input, const MemoryConfig& output_mem_config);
ComplexTensor _polar(const ComplexTensor& input, const MemoryConfig& output_mem_config);

template <ComplexUnaryOpType OpType>
struct OpHandler_complex_type1;
Expand Down Expand Up @@ -98,6 +100,13 @@ struct OpHandler_complex_type2<ComplexUnaryOpType::COMPLEX_RECIP> {
}
};

// Routes ComplexUnaryOpType::POLAR through the type-2 handler table
// (ComplexTensor in -> ComplexTensor out) by forwarding to _polar.
template <>
struct OpHandler_complex_type2<ComplexUnaryOpType::POLAR> {
    static ComplexTensor handle( const ComplexTensor& input, const MemoryConfig& output_mem_config ) {
        return _polar(input, output_mem_config);
    }
};

template <ComplexUnaryOpType OpType>
auto get_function_complex_unary() {
return &OpHandler_complex_type1<OpType>::handle;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace ttnn::operations::complex_unary_backward {
// polar fwd op uses sin and cos hence input_b range is (0, 2*pi)
std::vector<ComplexTensor> _polar_bw(const ComplexTensor& grad, const ComplexTensor& input, const MemoryConfig& output_mem_config) {
std::vector<ComplexTensor> grad_tensor;
ComplexTensor result = polar(input, output_mem_config);
ComplexTensor result = ttnn::operations::complex_unary::_polar(input, output_mem_config);
Tensor abs_result = ttnn::operations::complex_unary::_abs(result, output_mem_config);
Tensor sgn_result_r = where(ttnn::eqz(abs_result, output_mem_config), ttnn::operations::creation::zeros_like(result.real(), result.real().get_dtype(), result.real().get_layout(), std::nullopt, output_mem_config), ttnn::multiply(result.real(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Tensor sgn_result_i = where(ttnn::eqz(abs_result, output_mem_config), ttnn::operations::creation::zeros_like(result.imag(), result.imag().get_dtype(), result.imag().get_layout(), std::nullopt, output_mem_config), ttnn::multiply(result.imag(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Expand Down

0 comments on commit eae7837

Please sign in to comment.