Skip to content

Commit

Permalink
Revert "#11043: Overload complex forward ops" (#11091)
Browse files Browse the repository at this point in the history
This reverts commit 8d4008d.
  • Loading branch information
patrickroberts authored Aug 5, 2024
1 parent ec0cc14 commit b8dde20
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 132 deletions.
2 changes: 2 additions & 0 deletions tests/ttnn/unit_tests/operations/test_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def test_level2_sub(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.mul doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down Expand Up @@ -369,6 +370,7 @@ def test_level2_mul(bs, memcfg, dtype, device, function_level_defaults):
assert passing


@pytest.mark.skip(reason="This test is failing because ttnn.div doesn't support complex tensors")
@pytest.mark.parametrize(
"memcfg",
(
Expand Down
25 changes: 0 additions & 25 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp"

namespace ttnn::operations::binary {

Expand Down Expand Up @@ -219,30 +218,6 @@ Tensor BinaryOperation<binary_op_type, in_place>::operator()(
}


template <BinaryOpType binary_op_type, bool in_place>
ComplexTensor BinaryOperation<binary_op_type, in_place>::operator()(
const ComplexTensor &input_a,
const ComplexTensor &input_b,
const ttnn::MemoryConfig &output_mem_config) {
if constexpr(binary_op_type == BinaryOpType::MUL) {
Tensor re_part = ttnn::subtract(
ttnn::multiply(input_a[0],input_b[0],std::nullopt,output_mem_config),
ttnn::multiply(input_a[1],input_b[1],std::nullopt,output_mem_config),
std::nullopt, output_mem_config);

Tensor im_part = ttnn::add(
ttnn::multiply(input_a[0],input_b[1],std::nullopt,output_mem_config),
ttnn::multiply(input_a[1],input_b[0],std::nullopt,output_mem_config),
std::nullopt, output_mem_config);

return ComplexTensor({ re_part, im_part });
}else if constexpr(binary_op_type == BinaryOpType::DIV_FAST) {
return ttnn::multiply( input_a, ttnn::operations::complex_unary::_reciprocal( input_b , output_mem_config ), output_mem_config ); //TODO: Overload reciprocal
}else {
TT_THROW("Unsupported operation (expected MUL or DIV_FAST)");
}
}

template <BinaryOpType binary_op_type, bool in_place>
Tensor RelationalBinary<binary_op_type, in_place>::operator()(
uint8_t queue_id,
Expand Down
11 changes: 2 additions & 9 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "ttnn/decorators.hpp"
#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
#include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
#include "ttnn/operations/eltwise/complex/complex.hpp"

namespace ttnn {

Expand Down Expand Up @@ -59,12 +58,6 @@ struct BinaryOperation {
const std::optional<Tensor> &optional_output_tensor = std::nullopt,
std::optional<unary::FusedActivations> activations = std::nullopt,
std::optional<unary::UnaryWithParam> input_tensor_a_activation = std::nullopt);

static ComplexTensor operator()(
const ComplexTensor &input_tensor_a_arg,
const ComplexTensor &input_tensor_b_arg,
const MemoryConfig &memory_config);

};

template <BinaryOpType binary_op_type, bool in_place>
Expand Down Expand Up @@ -132,7 +125,7 @@ constexpr auto subtract = ttnn::register_operation_with_auto_launch_op<
constexpr auto subtract_ = ttnn::register_operation_with_auto_launch_op<
"ttnn::subtract_",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SUB, true>>();
constexpr auto multiply = ttnn::register_operation<
constexpr auto multiply = ttnn::register_operation_with_auto_launch_op<
"ttnn::multiply",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::MUL, false>>();
constexpr auto multiply_ = ttnn::register_operation_with_auto_launch_op<
Expand Down Expand Up @@ -176,7 +169,7 @@ constexpr auto logaddexp2 = ttnn::register_operation_with_auto_launch_op<
constexpr auto squared_difference = ttnn::register_operation_with_auto_launch_op<
"ttnn::squared_difference",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::SQUARED_DIFFERENCE, false>>();
constexpr auto divide = ttnn::register_operation<
constexpr auto divide = ttnn::register_operation_with_auto_launch_op<
"ttnn::divide",
operations::binary::BinaryOperation<operations::binary::BinaryOpType::DIV_FAST, false>>();

Expand Down
103 changes: 5 additions & 98 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,99 +104,6 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati
py::arg("queue_id") = 0});
}

template <typename binary_operation_t>
void bind_binary_operation_with_complex(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, dtype: Optional[ttnn.DataType] = None, activations: Optional[List[str]] = None) -> ttnn.Tensor or ComplexTensor
{2}
Supports broadcasting.
Args:
* :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor)
* :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number)
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
* :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
* :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
* :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
Example:
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(tensor1, tensor2)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);

bind_registered_operation(
module,
operation,
doc,
// tensor and scalar
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const ttnn::Tensor& input_tensor_a,
const float scalar,
const std::optional<const DataType>& dtype,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const std::optional<unary::FusedActivations>& activations,
const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
const uint8_t& queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor_a, scalar, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("activations") = std::nullopt,
py::arg("input_tensor_a_activation") = std::nullopt,
py::arg("queue_id") = 0},

// tensor and tensor
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const DataType>& dtype,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const std::optional<unary::FusedActivations>& activations,
const std::optional<unary::UnaryWithParam>& input_tensor_a_activation,
const uint8_t& queue_id) -> ttnn::Tensor {
return self(queue_id, input_tensor_a, input_tensor_b, dtype, memory_config, output_tensor, activations, input_tensor_a_activation);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("dtype") = std::nullopt,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("activations") = std::nullopt,
py::arg("input_tensor_a_activation") = std::nullopt,
py::arg("queue_id") = 0},

// complex tensor
ttnn::pybind_overload_t{
[](const binary_operation_t& self,
const ComplexTensor& input_tensor_a,
const ComplexTensor& input_tensor_b,
const MemoryConfig& memory_config) {
return self(input_tensor_a, input_tensor_b, memory_config);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("memory_config")});
}

template <typename binary_operation_t>
void bind_binary_composite(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
Expand Down Expand Up @@ -380,11 +287,11 @@ void bind_div_like_ops(py::module& module, const binary_operation_t& operation,
template <typename binary_operation_t>
void bind_div(py::module& module, const binary_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ttnn.Tensor, ComplexTensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Args:
* :attr:`input_tensor_a` (ttnn.Tensor or ComplexTensor)
* :attr:`input_tensor_b` (ttnn.Tensor or Number or ComplexTensor)
* :attr:`input_tensor_a`
* :attr:`input_tensor_b` (ttnn.Tensor or Number)
* :attr:`accurate_mode`: ``false`` if input_tensor_b is non-zero, else ``true``.
* :attr:`round_mode`
Expand Down Expand Up @@ -557,7 +464,7 @@ void py_module(py::module& module) {
R"doc(Subtracts :attr:`input_tensor_b` from :attr:`input_tensor_a` and returns the tensor with the same layout as :attr:`input_tensor_a` in-place
.. math:: \mathrm{{input\_tensor\_a}}_i - \mathrm{{input\_tensor\_b}}_i)doc");

detail::bind_binary_operation_with_complex(
detail::bind_binary_operation(
module,
ttnn::multiply,
R"doc(Multiplies :attr:`input_tensor_a` by :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
Expand Down Expand Up @@ -647,7 +554,7 @@ void py_module(py::module& module) {
R"doc(Compute bias_gelu of :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
.. math:: \mathrm{{input\_tensor\_a}}_i || \mathrm{{input\_tensor\_b}}_i)doc");

detail::bind_binary_operation_with_complex(
detail::bind_binary_operation(
module,
ttnn::divide,
R"doc(Divides :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`
Expand Down

0 comments on commit b8dde20

Please sign in to comment.