#10890: Add rounding mode support for rdiv op
mouliraj-mcw committed Aug 10, 2024
1 parent 202a064 commit d0c11b1
Showing 11 changed files with 114 additions and 89 deletions.
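
At a glance: this commit folds the old tt_lib rfloor_div and unary_rdiv_trunc composites into a single ttnn.rdiv op with a round_mode keyword. A minimal usage sketch of the new call (the device id and tensor values are illustrative, not from the diff):

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# rdiv computes value / input elementwise; round_mode is the keyword this
# commit adds ("None" for no rounding, "trunc", or "floor").
input_tensor = ttnn.from_torch(
    torch.tensor([[1, 2], [4, 8]], dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
output = ttnn.rdiv(input_tensor, 10.0, round_mode="floor")
print(ttnn.to_torch(output))  # floor(10 / x) -> [[10, 5], [2, 1]]

ttnn.close_device(device)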
5 changes: 0 additions & 5 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -256,13 +256,8 @@ Enums
Tensor elementwise operations
=============================

.. autofunction:: tt_lib.tensor.unary_rdiv_trunc

.. autofunction:: tt_lib.tensor.assign

.. autofunction:: tt_lib.tensor.rfloor_div


Tensor manipulation operations
==============================

4 changes: 2 additions & 2 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -856,7 +856,7 @@ def eltwise_rfloor_div(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.rfloor_div(value, t0, output_mem_config=output_mem_config)
t1 = ttnn.rdiv(t0, value, round_mode="floor", memory_config=output_mem_config)

return tt2torch_tensor(t1)

@@ -929,7 +929,7 @@ def eltwise_unary_rdiv_trunc(
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.unary_rdiv_trunc(value, t0, output_mem_config=output_mem_config)
t1 = ttnn.rdiv(t0, value, round_mode="trunc", memory_config=output_mem_config)

return tt2torch_tensor(t1)

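The same one-line migration applies to both sweep-test hunks above; schematically (the old calls are copied from the removed lines):

# Before (tt_lib composites removed by this commit):
#   ttl.tensor.rfloor_div(value, t0, output_mem_config=output_mem_config)
#   ttl.tensor.unary_rdiv_trunc(value, t0, output_mem_config=output_mem_config)
# After (one ttnn op; note the arguments flip to tensor-first):
#   ttnn.rdiv(t0, value, round_mode="floor", memory_config=output_mem_config)
#   ttnn.rdiv(t0, value, round_mode="trunc", memory_config=output_mem_config)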
25 changes: 25 additions & 0 deletions tests/ttnn/unit_tests/operations/test_composite.py
@@ -764,3 +764,28 @@ def test_unary_celu(input_shapes, param, device):

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@skip_for_grayskull()
@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"param",
{random.uniform(0, 100) for _ in range(5)},
)
@pytest.mark.parametrize("round_mode", ["None", "trunc", "floor"])
def test_unary_rdiv(input_shapes, param, round_mode, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device)

output_tensor = ttnn.rdiv(input_tensor, param, round_mode=round_mode)
golden_function = ttnn.get_golden_function(ttnn.rdiv)
golden_tensor = golden_function(in_data, param, round_mode=round_mode)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
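
For intuition on the round_mode parametrization above: "trunc" and "floor" agree on positive quotients and differ on negative ones. A small torch check with illustrative values:

import torch

x = torch.tensor([4.0, -4.0])
full = torch.full_like(x, 10.0)  # scalar numerator broadcast to x's shape
print(full / x)                                   # tensor([ 2.5000, -2.5000])
print(torch.div(full, x, rounding_mode="trunc"))  # tensor([ 2., -2.])
print(torch.div(full, x, rounding_mode="floor"))  # tensor([ 2., -3.])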
@@ -29,30 +29,6 @@ namespace tt {

namespace tt_metal {

Tensor _unary_rdiv_trunc(
float value,
const Tensor& input,
const MemoryConfig& output_mem_config) {
auto arch = input.device()->arch();
TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
Tensor result = ttnn::multiply(ttnn::full_like(input, value), ttnn::reciprocal(input));
return ttnn::trunc(result);
}
Tensor unary_rdiv_trunc(
float value,
const Tensor& input,
const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _unary_rdiv_trunc)(value, input, output_mem_config);
}

Tensor _rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
Tensor result = ttnn::multiply(ttnn::full_like(input, value), ttnn::reciprocal(input));
return ttnn::floor(result, output_mem_config);
}
Tensor rfloor_div(float value, const Tensor& input, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _rfloor_div)(value, input, output_mem_config);
}

} // namespace tt_metal

} // namespace tt
@@ -27,16 +27,6 @@ using binary_tensor_op_t = Tensor(const Tensor& a, const Tensor& b);
// Note: inline doesn't play well with pybind, so we keep a few functions non-inlined.


Tensor unary_rdiv_trunc(
float value,
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor rfloor_div(
float value,
const Tensor& input,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
@@ -75,40 +75,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
)doc");
#endif

m_tensor.def("unary_rdiv_trunc", py::overload_cast<float, const Tensor&, const MemoryConfig&>(&unary_rdiv_trunc),
py::arg("value").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs the element-wise division of a scalar ``value`` by a tensor ``input`` and rounds the result using trunc mode. Support provided only for Wormhole_B0.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"value", "Numerator value", "float", "", "Yes"
"input", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");


m_tensor.def("rfloor_div", py::overload_cast<float, const Tensor&, const MemoryConfig&>(&rfloor_div),
py::arg("value").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,R"doc(
Performs the element-wise floor division of a scalar ``value`` by a tensor ``input``. Support provided only for Wormhole_B0.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"value", "Numerator value", "float", "", "Yes"
"input", "Denominator Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");


m_tensor.def(
"lamb_optimizer",
&lamb_optimizer,
@@ -666,11 +666,17 @@ Tensor _polygamma(const Tensor& input_a, int32_t k, const std::optional<MemoryConfig>& output_mem_config) {
}

//rdiv
Tensor _rdiv(uint8_t queue_id, const Tensor& input_tensor, float value, const std::optional<MemoryConfig>& memory_config = std::nullopt, std::optional<Tensor> optional_output_tensor = std::nullopt) {
Tensor ExecuteRdiv::operator()(uint8_t queue_id, const Tensor& input_tensor, float value, string round_mode, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) {
float t_inf = std::numeric_limits<float>::infinity();
Tensor recip_result = ttnn::reciprocal(queue_id, input_tensor, memory_config, optional_output_tensor);
Tensor result = ttnn::multiply(queue_id, recip_result, value, std::nullopt, memory_config, optional_output_tensor);

if(round_mode == "trunc"){
result = trunc(result);
}
else if(round_mode == "floor"){
result = floor(result);
}
return ttnn::where(ttnn::eqz(queue_id, input_tensor, memory_config), t_inf, result, memory_config, optional_output_tensor);
}

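The composite above is reciprocal-multiply, an optional rounding pass, then a divide-by-zero fixup. A host-side torch sketch of the same sequence (rdiv_reference is a hypothetical helper for illustration, not part of the codebase):

import torch

def rdiv_reference(input_tensor: torch.Tensor, value: float, round_mode: str = "None") -> torch.Tensor:
    # value / input via reciprocal-multiply, as in ExecuteRdiv::operator().
    result = torch.reciprocal(input_tensor) * value
    if round_mode == "trunc":
        result = torch.trunc(result)
    elif round_mode == "floor":
        result = torch.floor(result)
    # Zero denominators map to +inf, mirroring the eqz/t_inf branch above.
    inf = torch.tensor(float("inf"), dtype=result.dtype)
    return torch.where(input_tensor == 0, inf, result)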
@@ -46,7 +46,6 @@ enum class UnaryCompositeOpType {
GEGLU,
SWIGLU,
POW,
RDIV,
TRIL,
TRIU,
ROUND,
@@ -99,7 +98,6 @@ Tensor _power(uint8_t, const Tensor&, float, const std::optional<MemoryConfig>&, std::optional<Tensor>);
Tensor _power(uint8_t, const Tensor&, uint32_t, const std::optional<MemoryConfig>&, std::optional<Tensor>);
Tensor _tril(const Tensor&, int32_t diag = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _triu(const Tensor&, int32_t diag = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _rdiv(uint8_t, const Tensor&, float, const std::optional<MemoryConfig>&, std::optional<Tensor>);
Tensor _round(const Tensor&, int32_t decimal =0 , const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor _polygamma(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _hardshrink(const Tensor& a, float lambd = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -378,13 +376,6 @@ struct OpHandler<UnaryCompositeOpType::POW> {
}
};

template <>
struct OpHandler<UnaryCompositeOpType::RDIV> {
static Tensor handle(uint8_t queue_id, const Tensor& input_tensor, float value, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor){
return _rdiv(queue_id, input_tensor, value, memory_config, optional_output_tensor);
}
};

template <>
struct OpHandler<UnaryCompositeOpType::HARDSHRINK> {
static Tensor handle(const Tensor& t1, float lambd, const std::optional<MemoryConfig>& mem_cfg ) {
5 changes: 2 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -119,10 +119,9 @@ struct ExecuteRdiv {
uint8_t queue_id,
const Tensor& input_tensor,
float value,
string round_mode = "None",
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt) {
return OpHandler<UnaryCompositeOpType::RDIV>::handle(queue_id, input_tensor, value, memory_config.value_or(input_tensor.memory_config()), optional_output_tensor);
}
std::optional<Tensor> optional_output_tensor = std::nullopt);
};

} // namespace unary
67 changes: 66 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -408,6 +408,59 @@ void bind_unary_operation_with_dim_parameter(
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& parameter_name_b, const std::string& parameter_b_doc, const std::string parameter_b_value, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, {2}: float, *, {4}: string, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{7}
Args:
* :attr:`input_tensor`
* :attr:`{2}` (float): {3}
Keyword args:
* :attr:`{4}` (string): {5} , Default value = {6}
* :attr:`memory_config` [ttnn.MemoryConfig]: memory config for the output tensor
* :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
Example:
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, {2}, {4} = {6})
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
parameter_name_a,
parameter_a_doc,
parameter_name_b,
parameter_b_doc,
parameter_b_value,
description);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_operation_t& self,
const ttnn::Tensor& input_tensor,
float parameter_a,
string parameter_b,
const std::optional<MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const uint8_t& queue_id) {
return self(queue_id, input_tensor, parameter_a, parameter_b, memory_config, output_tensor);
},
py::arg("input_tensor"),
py::arg(parameter_name_a.c_str()),
py::kw_only(),
py::arg(parameter_name_b.c_str()) = parameter_b_value,
py::arg("memory_config") = std::nullopt,
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = 0});
}

template <typename unary_operation_t>
void bind_softplus(py::module& module, const unary_operation_t& operation) {
@@ -459,6 +512,7 @@ void bind_softplus(const unary_operation_t& operation) {
py::arg("output_tensor") = std::nullopt,
py::arg("queue_id") = 0});
}

template <typename unary_operation_t>
void bind_sigmoid_accurate(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
@@ -1222,7 +1276,6 @@ void py_module(py::module& module) {
// Unaries with float parameter
detail::bind_unary_operation_with_float_parameter(module, ttnn::elu, "alpha", "The alpha parameter for the ELU function", "");
detail::bind_unary_operation_with_float_parameter(module, ttnn::rsub, "value", "subtrahent value which is actually calculated as minuend", "Returns tensor with respective elements of the input tensor subtracted from the value.");
detail::bind_unary_operation_with_float_parameter(module, ttnn::rdiv, "value", "denominator value which is actually calculated as numerator float value >= 0 ", "Returns tensor with scalar value divided by each of respective elements of the input tensor.");
detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", "");
detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "slope", "The slope parameter for the Leaky ReLU function", "");
detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_max, "upper_limit", "The max value for ReLU function", "This function caps off the input to a max value and a min value of 0");
@@ -1366,6 +1419,18 @@ void py_module(py::module& module) {
ttnn::rpow,
"exponent", "exponent value",
R"doc(Performs rpow function on :attr:`input_tensor`, :attr:`exponent`.)doc");

detail::bind_unary_rdiv(
module,
ttnn::rdiv,
"value", "denominator value which is actually calculated as numerator float value >= 0",
"round_mode", "rounding_mode value", "None",
R"doc(Performs the element-wise division of a scalar ``value`` by a tensor ``input`` and rounds the result using round_mode. Support provided only for Wormhole_B0.
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.)doc");

}

} // namespace unary
12 changes: 12 additions & 0 deletions ttnn/ttnn/operations/unary.py
@@ -638,4 +638,16 @@ def _golden_function_frac(input_tensor_a, *args, **kwargs):


ttnn.attach_golden_function(ttnn._ttnn.operations.unary.frac, golden_function=_golden_function_frac)


def _golden_function_rdiv(input_tensor_a, value, *args, round_mode=None, **kwargs):
import torch

if round_mode == "None":
round_mode = None

return torch.div(torch.full_like(input_tensor_a, value), input_tensor_a, rounding_mode=round_mode)


ttnn.attach_golden_function(ttnn._ttnn.operations.unary.rdiv, golden_function=_golden_function_rdiv)
__all__ = []
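
With the golden function attached, device-vs-host parity can be spot-checked along these lines (a sketch; the shape, scalar, and tolerances are illustrative):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
in_data = torch.rand(1, 1, 32, 32, dtype=torch.bfloat16) * 200 - 100
input_tensor = ttnn.from_torch(in_data, layout=ttnn.TILE_LAYOUT, device=device)

device_out = ttnn.to_torch(ttnn.rdiv(input_tensor, 7.5, round_mode="trunc"))
golden_out = ttnn.get_golden_function(ttnn.rdiv)(in_data, 7.5, round_mode="trunc")
assert torch.allclose(device_out, golden_out, rtol=0.02, atol=0.05)

ttnn.close_device(device)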
