From d0c11b1dee92810e3cdb4a72acaf7ad4d924e2e6 Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Tue, 30 Jul 2024 12:19:50 +0000
Subject: [PATCH] #10890: Add rounding mode support for rdiv op

---
 docs/source/ttnn/ttnn/dependencies/tt_lib.rst |  5 --
 .../sweep_tests/tt_lib_ops.py                 |  4 +-
 .../unit_tests/operations/test_composite.py   | 25 +++
 .../op_library/composite/composite_ops.cpp    | 24 ------
 .../op_library/composite/composite_ops.hpp    | 10 ---
 .../tt_lib_bindings_tensor_composite_ops.cpp  | 34 ----------
 .../unary/device/unary_composite_op.cpp       |  8 ++-
 .../unary/device/unary_composite_op.hpp       |  9 ---
 .../eltwise/unary/unary_composite.hpp         |  5 +-
 .../operations/eltwise/unary/unary_pybind.hpp | 67 ++++++++++++++++++-
 ttnn/ttnn/operations/unary.py                 | 12 ++++
 11 files changed, 114 insertions(+), 89 deletions(-)

diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index 5c63e2c34bf..c0141575436 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -256,13 +256,8 @@ Enums
 Tensor elementwise operations
 =============================
 
-.. autofunction:: tt_lib.tensor.unary_rdiv_trunc
-
 .. autofunction:: tt_lib.tensor.assign
 
-.. autofunction:: tt_lib.tensor.rfloor_div
-
-
 Tensor manipulation operations
 ==============================
 
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index be739649dfc..b4890dd47b2 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -856,7 +856,7 @@ def eltwise_rfloor_div(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.rfloor_div(value, t0, output_mem_config=output_mem_config)
+    t1 = ttnn.rdiv(t0, value, round_mode="floor", memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
@@ -929,7 +929,7 @@ def eltwise_unary_rdiv_trunc(
     **kwargs,
 ):
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.unary_rdiv_trunc(value, t0, output_mem_config=output_mem_config)
+    t1 = ttnn.rdiv(t0, value, round_mode="trunc", memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
diff --git a/tests/ttnn/unit_tests/operations/test_composite.py b/tests/ttnn/unit_tests/operations/test_composite.py
index 00d8749bad5..3e6810f1d67 100644
--- a/tests/ttnn/unit_tests/operations/test_composite.py
+++ b/tests/ttnn/unit_tests/operations/test_composite.py
@@ -764,3 +764,28 @@ def test_unary_celu(input_shapes, param, device):
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
+
+
+@skip_for_grayskull()
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "param",
+    {random.uniform(0, 100) for _ in range(5)},
+)
+@pytest.mark.parametrize("round_mode", ["None", "trunc", "floor"])
+def test_unary_rdiv(input_shapes, param, round_mode, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    output_tensor = ttnn.rdiv(input_tensor, param, round_mode=round_mode)
+    golden_function = ttnn.get_golden_function(ttnn.rdiv)
+    golden_tensor = golden_function(in_data, param, round_mode=round_mode)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
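
Note on the semantics pinned down by test_unary_rdiv above: rdiv computes value / input elementwise and then optionally rounds, and "trunc" and "floor" disagree only when the quotient is negative. A minimal torch-only sketch of the expected behavior (names here are illustrative, not part of the patch; it mirrors the golden function registered in unary.py at the end of this patch):

    import torch

    def rdiv_reference(x: torch.Tensor, value: float, round_mode=None) -> torch.Tensor:
        # Divide the scalar by every element, then round per mode (None = keep the exact quotient).
        return torch.div(torch.full_like(x, value), x, rounding_mode=round_mode)

    x = torch.tensor([2.0, -2.0])
    print(rdiv_reference(x, 7.0))                      # tensor([ 3.5000, -3.5000])
    print(rdiv_reference(x, 7.0, round_mode="trunc"))  # tensor([ 3., -3.])  rounds toward zero
    print(rdiv_reference(x, 7.0, round_mode="floor"))  # tensor([ 3., -4.])  rounds toward -inf
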
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
index 24e94ab099b..8aad2deeb05 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
@@ -29,30 +29,6 @@ namespace tt {
 
 namespace tt_metal {
 
-Tensor _unary_rdiv_trunc(
-    float value,
-    const Tensor& input,
-    const MemoryConfig& output_mem_config) {
-    auto arch = input.device()->arch();
-    TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
-    Tensor result = ttnn::multiply(ttnn::full_like(input, value), ttnn::reciprocal(input));
-    return ttnn::trunc(result);
-}
-Tensor unary_rdiv_trunc(
-    float value,
-    const Tensor& input,
-    const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _unary_rdiv_trunc)(value, input, output_mem_config);
-}
-
-Tensor _rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
-    Tensor result = ttnn::multiply(ttnn::full_like(input, value), ttnn::reciprocal(input));
-    return ttnn::floor(result, output_mem_config);
-}
-Tensor rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
-    return operation::decorate_as_composite(__func__, _rfloor_div)(value, input, output_mem_config);
-}
-
 }  // namespace tt_metal
 
 }  // namespace tt
 
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
index 2ea2b4030cd..96499d21838 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
@@ -27,16 +27,6 @@ using binary_tensor_op_t = Tensor(const Tensor& a, const Tensor& b);
 
 // Note: inline doesn't allow pybind to work well so we keep few function not inlined.
 
-Tensor unary_rdiv_trunc(
-    float value,
-    const Tensor& input,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
-Tensor rfloor_div(
-    float value,
-    const Tensor& input,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 }  // namespace tt_metal
 
 }  // namespace tt
diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index 3c097020fa7..eece3882969 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -75,40 +75,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
             )doc");
 #endif
 
-    m_tensor.def("unary_rdiv_trunc", py::overload_cast<float, const Tensor&, const MemoryConfig&>(&unary_rdiv_trunc),
-        py::arg("value").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
-        Performs the element-wise division of a scalar ``value`` by a tensor ``input`` and rounds the result using trunc mode. Support provided only for Wormhole_B0.
-
-        Input tensor must have BFLOAT16 data type.
-
-        Output tensor will have BFLOAT16 data type.
-
-        .. csv-table::
-            :header: "Argument", "Description", "Data type", "Valid range", "Required"
-
-            "value", "Numerator value", "float", "", "Yes"
-            "input", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-    )doc");
-
-    m_tensor.def("rfloor_div", py::overload_cast<float, const Tensor&, const MemoryConfig&>(&rfloor_div),
-        py::arg("value").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
-        Performs the element-wise floor division of a scalar ``value`` by a tensor ``input``. Support provided only for Wormhole_B0.
-
-        Input tensor must have BFLOAT16 data type.
-
-        Output tensor will have BFLOAT16 data type.
-
-        .. csv-table::
-            :header: "Argument", "Description", "Data type", "Valid range", "Required"
-
-            "value", "Numerator value", "float", "", "Yes"
-            "input", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
-            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
-    )doc");
-
     m_tensor.def(
         "lamb_optimizer",
         &lamb_optimizer,
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index de0942c698d..d87bad0f779 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -666,11 +666,17 @@ Tensor _polygamma(const Tensor& input_a, int32_t k, const std::optional<MemoryConfig>& output_mem_config) {
 
-Tensor _rdiv(uint8_t queue_id, const Tensor& input_tensor, float value, const std::optional<MemoryConfig>& memory_config = std::nullopt, std::optional<Tensor> optional_output_tensor = std::nullopt) {
+Tensor ExecuteRdiv::operator()(uint8_t queue_id, const Tensor& input_tensor, float value, string round_mode, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) {
     float t_inf = std::numeric_limits<float>::infinity();
     Tensor recip_result = ttnn::reciprocal(queue_id, input_tensor, memory_config, optional_output_tensor);
     Tensor result = ttnn::multiply(queue_id, recip_result, value, std::nullopt, memory_config, optional_output_tensor);
+    if (round_mode == "trunc") {
+        result = trunc(result);
+    }
+    else if (round_mode == "floor") {
+        result = floor(result);
+    }
     return ttnn::where(ttnn::eqz(queue_id, input_tensor, memory_config), t_inf, result, memory_config, optional_output_tensor);
 }
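
The rewritten ExecuteRdiv::operator() above keeps the existing reciprocal-multiply decomposition and the eqz/where guard that maps zero denominators to +inf; only the trunc/floor branching is new. A hedged Python rendering of that control flow (rdiv_sketch is an illustrative name, not part of the patch):

    import torch

    def rdiv_sketch(x: torch.Tensor, value: float, round_mode: str = "None") -> torch.Tensor:
        # value / x computed as value * (1 / x), as in the composite op.
        result = torch.reciprocal(x) * value
        if round_mode == "trunc":
            result = torch.trunc(result)
        elif round_mode == "floor":
            result = torch.floor(result)
        # Zero denominators are patched to +inf, mirroring ttnn::where(ttnn::eqz(...), t_inf, ...).
        return torch.where(x == 0, torch.full_like(result, float("inf")), result)
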
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
index 9abb6b104fb..3dd905d8480 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
@@ -46,7 +46,6 @@ enum class UnaryCompositeOpType {
     GEGLU,
     SWIGLU,
     POW,
-    RDIV,
     TRIL,
     TRIU,
     ROUND,
@@ -99,7 +98,6 @@ Tensor _power(uint8_t, const Tensor&, float, const std::optional<MemoryConfig>&, std::optional<Tensor>);
 Tensor _power(uint8_t, const Tensor&, uint32_t, const std::optional<MemoryConfig>&, std::optional<Tensor>);
 Tensor _tril(const Tensor&, int32_t diag = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _triu(const Tensor&, int32_t diag = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-Tensor _rdiv(uint8_t, const Tensor&, float, const std::optional<MemoryConfig>&, std::optional<Tensor>);
 Tensor _round(const Tensor&, int32_t decimal = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _polygamma(const Tensor&, int32_t, const std::optional<MemoryConfig>&);
 Tensor _hardshrink(const Tensor& a, float lambd = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -378,13 +376,6 @@ struct OpHandler<UnaryCompositeOpType::POW> {
     }
 };
 
-template <>
-struct OpHandler<UnaryCompositeOpType::RDIV> {
-    static Tensor handle(uint8_t queue_id, const Tensor& input_tensor, float value, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) {
-        return _rdiv(queue_id, input_tensor, value, memory_config, optional_output_tensor);
-    }
-};
-
 template <>
 struct OpHandler<UnaryCompositeOpType::HARDSHRINK> {
     static Tensor handle(const Tensor& t1, float lambd, const std::optional<MemoryConfig>& mem_cfg) {
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
index 4673efabbbc..b90979cb962 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -119,10 +119,9 @@ struct ExecuteRdiv {
         uint8_t queue_id,
         const Tensor& input_tensor,
         float value,
+        string round_mode = "None",
         const std::optional<MemoryConfig>& memory_config = std::nullopt,
-        std::optional<Tensor> optional_output_tensor = std::nullopt) {
-        return OpHandler<UnaryCompositeOpType::RDIV>::handle(queue_id, input_tensor, value, memory_config.value_or(input_tensor.memory_config()), optional_output_tensor);
-    }
+        std::optional<Tensor> optional_output_tensor = std::nullopt);
 };
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index 071d8214c3a..8fd65ab341d 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -408,6 +408,59 @@ void bind_unary_operation_with_dim_parameter(
             py::arg("memory_config") = std::nullopt});
 }
 
+template <typename unary_operation_t>
+void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& parameter_name_b, const std::string& parameter_b_doc, const std::string parameter_b_value, const std::string& description) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, {2}: float, *, {4}: string, memory_config: ttnn.MemoryConfig) -> ttnn.Tensor
+
+        {7}
+
+        Args:
+            * :attr:`input_tensor`
+            * :attr:`{2}` (float): {3}
+
+        Keyword args:
+            * :attr:`{4}` (string): {5}, Default value = {6}
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Example:
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor, {2}, {4} = {6})
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name_a,
+        parameter_a_doc,
+        parameter_name_b,
+        parameter_b_doc,
+        parameter_b_value,
+        description);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const ttnn::Tensor& input_tensor,
+               float parameter_a,
+               string parameter_b,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::optional<ttnn::Tensor>& output_tensor,
+               const uint8_t& queue_id) {
+                return self(queue_id, input_tensor, parameter_a, parameter_b, memory_config, output_tensor);
+            },
+            py::arg("input_tensor"),
+            py::arg(parameter_name_a.c_str()),
+            py::kw_only(),
+            py::arg(parameter_name_b.c_str()) = parameter_b_value,
+            py::arg("memory_config") = std::nullopt,
+            py::arg("output_tensor") = std::nullopt,
+            py::arg("queue_id") = 0});
+}
+
 template <typename unary_operation_t>
 void bind_softplus(py::module& module, const unary_operation_t& operation) {
@@ -459,6 +512,7 @@ void bind_softplus(py::module& module, const unary_operation_t& operation) {
             py::arg("output_tensor") = std::nullopt,
             py::arg("queue_id") = 0});
 }
+
 template <typename unary_operation_t>
 void bind_sigmoid_accurate(py::module& module, const unary_operation_t& operation) {
     auto doc = fmt::format(
@@ -1222,7 +1276,6 @@ void py_module(py::module& module) {
     // Unaries with float parameter
     detail::bind_unary_operation_with_float_parameter(module, ttnn::elu, "alpha", "The alpha parameter for the ELU function", "");
     detail::bind_unary_operation_with_float_parameter(module, ttnn::rsub, "value", "subtrahent value which is actually calculated as minuend", "Returns tensor with respective elements of the input tensor subtracted from the value.");
-    detail::bind_unary_operation_with_float_parameter(module, ttnn::rdiv, "value", "denominator value which is actually calculated as numerator float value >= 0 ", "Returns tensor with scalar value divided by each of respective elements of the input tensor.");
     detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", "");
     detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "slope", "The slope parameter for the Leaky ReLU function", "");
     detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_max, "upper_limit", "The max value for ReLU function", "This function caps off the input to a max value and a min value of 0");
@@ -1366,6 +1419,18 @@ void py_module(py::module& module) {
         ttnn::rpow,
         "exponent",
         "exponent value",
         R"doc(Performs rpow function on :attr:`input_tensor`, :attr:`exponent`.)doc");
+
+    detail::bind_unary_rdiv(
+        module,
+        ttnn::rdiv,
+        "value", "scalar numerator value (must be >= 0); each element of the input tensor is used as the denominator",
+        "round_mode", "rounding mode to apply to the result: \"None\", \"trunc\", or \"floor\"", "None",
+        R"doc(Performs the element-wise division of a scalar ``value`` by a tensor ``input`` and rounds the result using round_mode. Support provided only for Wormhole_B0.
+
+        Input tensor must have BFLOAT16 data type.
+
+        Output tensor will have BFLOAT16 data type.)doc");
 }
 
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 58e5ca2f15f..3da87d669be 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -638,4 +638,16 @@ def _golden_function_frac(input_tensor_a, *args, **kwargs):
 
 
 ttnn.attach_golden_function(ttnn._ttnn.operations.unary.frac, golden_function=_golden_function_frac)
+
+
+def _golden_function_rdiv(input_tensor_a, value, *args, round_mode=None, **kwargs):
+    import torch
+
+    if round_mode == "None":
+        round_mode = None
+
+    return torch.div(torch.full_like(input_tensor_a, value), input_tensor_a, rounding_mode=round_mode)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.rdiv, golden_function=_golden_function_rdiv)
 __all__ = []
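
For completeness, end-to-end usage of the binding added above, assuming a device handle is already open (the from_torch call mirrors the docstring example, with layout=ttnn.TILE_LAYOUT added here since eltwise ops generally expect tiled tensors). Note that round_mode is passed as the literal string "None"/"trunc"/"floor"; the golden function translates "None" to Python None before calling torch.div:

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    output = ttnn.rdiv(tensor, 4.0, round_mode="floor")  # elementwise 4 / tensor, floored
    ttnn.close_device(device)
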