Skip to content

Commit

Permalink
#0: Add prod in composite
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Apr 12, 2024
1 parent 3d151cb commit 817dc0a
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 69 deletions.
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,8 @@ Other Operations

.. autofunction:: tt_lib.tensor.xlogy

.. autofunction:: tt_lib.tensor.prod

.. autofunction:: tt_lib.tensor.addcmul

.. autofunction:: tt_lib.tensor.addcdiv
Expand Down
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,10 @@
"tt_lib_op": tt_lib_ops.eltwise_subalpha,
"pytorch_op": pytorch_ops.subalpha,
},
"eltwise-prod": {
"tt_lib_op": tt_lib_ops.eltwise_prod,
"pytorch_op": pytorch_ops.prod,
},
"eltwise-addalpha": {
"tt_lib_op": tt_lib_ops.eltwise_addalpha,
"pytorch_op": pytorch_ops.addalpha,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def custom_compare(*args, **kwargs):
return result


shapes = ([[1, 1, 32, 32]], [[1, 3, 320, 64]])
shapes = ([[1, 1, 32, 32]], [[4, 3, 32, 64]])
if is_wormhole_b0():
shapes = (shapes[0],)

Expand All @@ -60,58 +60,59 @@ def custom_compare(*args, **kwargs):
list(
product(
(
"lerp_binary",
"lerp_ternary",
"addcmul",
"addcdiv",
"min",
"max",
"swish",
"log1p",
"softplus",
"mish",
"silu",
"polyval",
"mac",
"cbrt",
"threshold",
"hypot",
"hardswish",
"hardsigmoid",
"ones_like",
"zeros_like",
"full_like",
"ones",
"empty",
"zeros",
"full",
"arange",
"hardshrink",
"softshrink",
"sinh",
"cosh",
"tanhshrink",
"xlogy",
"asinh",
"acosh",
"atanh",
"atan2",
"subalpha",
"bias_gelu_unary",
"addalpha",
"logit",
"logical_ori",
"logical_xor",
"logical_xori",
"logical_noti",
"logical_andi",
"isclose",
"digamma",
"lgamma",
"multigammaln",
"polygamma",
"nextafter",
"scatter",
# "lerp_binary",
# "lerp_ternary",
# "addcmul",
# "addcdiv",
# "min",
# "max",
# "swish",
# "log1p",
# "softplus",
# "mish",
# "silu",
# "polyval",
# "mac",
# "cbrt",
# "threshold",
# "hypot",
# "hardswish",
# "hardsigmoid",
# "ones_like",
# "zeros_like",
# "full_like",
# "ones",
# "empty",
# "zeros",
# "full",
# "arange",
# "hardshrink",
# "softshrink",
# "sinh",
# "cosh",
# "tanhshrink",
# "xlogy",
# "asinh",
# "acosh",
# "atanh",
# "atan2",
# "subalpha",
# "bias_gelu_unary",
# "addalpha",
# "logit",
# "logical_ori",
# "logical_xor",
# "logical_xori",
# "logical_noti",
# "logical_andi",
# "isclose",
# "digamma",
# "lgamma",
# "multigammaln",
# "polygamma",
# "nextafter",
# "scatter",
"prod",
),
shapes,
)
Expand All @@ -128,6 +129,7 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
options["hypot"] = (1, 100)
options["atan2"] = (-100, 100)
options["cbrt"] = (-1000, 1000)
options["prod"] = (1, 1.5)
options["hardsigmoid"] = (-100, 100)
options["hardswish"] = (-100, 100)
options["hardshrink"] = (-100, 100)
Expand Down Expand Up @@ -162,6 +164,13 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
torch.int32,
)
]
elif fn in ["prod"]: # "prod_cpu" not implemented for 'BFloat16'
datagen_func = [
generation_funcs.gen_func_with_cast(
partial(generator, low=options[fn][0], high=options[fn][1]),
torch.float32,
)
]
else:
datagen_func = [
generation_funcs.gen_func_with_cast(
Expand Down
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,10 @@ def xlogy(x, y, *args, **kwargs):
return torch.xlogy(x, y)


def prod(x, *args, **kwargs):
    """Reference op for the sweep tests: product-reduce ``x`` along dim 0.

    Mirrors the tt-lib composite ``prod`` (which reduces only the N dimension).
    Extra positional/keyword arguments supplied by the sweep harness are
    accepted and ignored.
    """
    return torch.prod(x, dim=0)


def ldexp(x, y, *args, **kwargs):
return torch.ldexp(x, y)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2262,6 +2262,7 @@ def binary_op(
eltwise_isneginf = make_unary_op(ttl.tensor.isneginf)
eltwise_isnan = make_unary_op(ttl.tensor.isnan)
eltwise_logical_not_unary = make_unary_op(ttl.tensor.logical_not_unary)
eltwise_prod = make_unary_op(ttl.tensor.prod)
eltwise_i0 = make_unary_op(ttl.tensor.i0)

################################################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_moreh_prod_dims(input_shape, dims, device):

cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_output_cpu = (
ttl.operations.primary.prod(tt_input, tt_output, dims=dims)
ttl.operations.primary.prod_nc(tt_input, tt_output, dims=dims)
.cpu()
.to(cpu_layout)
.unpad_from_tile(output_shape)
Expand Down
12 changes: 12 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "tt_eager/tensor/tensor_utils.hpp"
#include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_dnn/op_library/prod/prod_nc_op.hpp"
namespace tt {

namespace tt_metal {
Expand Down Expand Up @@ -901,6 +902,17 @@ Tensor xlogy(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o
return operation::decorate_as_composite(__func__, _xlogy)(input_a, input_b, output_mem_config);
}

// Composite prod implementation: product-reduce the input along dim 0 (N)
// by delegating to the primary prod_nc op with a zero-initialized output.
Tensor _prod(const Tensor& input_a, const MemoryConfig& output_mem_config) {
    // Only the batch (N) dimension is reduced.
    std::vector<int64_t> reduce_dims = {0};
    Shape input_shape = input_a.shape();
    // Output keeps rank 4 with dim 0 collapsed to 1.
    // NOTE(review): assumes a 4-D input shape — TODO confirm callers guarantee this.
    Shape output_shape = {1, input_shape[1], input_shape[2], input_shape[3]};
    Tensor reduce_target =
        zeros(output_shape, input_a.dtype(), input_a.layout(), input_a.device(), output_mem_config);
    return tt::operations::primary::prod_nc(input_a, reduce_target, reduce_dims, output_mem_config);
}
// Public composite entry point for prod (product-reduction along dim 0 — see _prod).
// decorate_as_composite registers the call under the name "prod" via __func__,
// so this wrapper must keep its exact name.
Tensor prod(const Tensor& input_a, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _prod)(input_a, output_mem_config);
}

Tensor _variance_impl(
const Tensor& y, const Tensor& mean_y, Tensor& y_minus_mean_y, const MemoryConfig& output_mem_config) {
constexpr float correction = 0.0f;
Expand Down
5 changes: 5 additions & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ Tensor logical_noti(
float immediate,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// prod
Tensor prod(
const Tensor& input_a,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

/*
Returns a new tensor with the signed angles in radians between vectors

Expand Down
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace operations {

namespace primary {

operation::ProgramWithCallbacks prod_nc(const Tensor &input, const Tensor &output, int64_t dim) {
operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor &output, int64_t dim) {
TT_ASSERT(dim == 0 || dim == 1);

////////////////////////////////////////////////////////////////////////////
Expand Down
13 changes: 3 additions & 10 deletions tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,15 @@ operation::ProgramWithCallbacks Prod::create_program(
auto& input = inputs.at(0);
auto& output = inputs.at(1);


if (dim == 0 || dim == 1) {
return prod_nc(input, output, dim);
} else {
return prod_nc(input, output, dim);
}
return prod_nc_format(input, output, dim);
}

inline Shape compute_output_shape(const Shape& input_shape, const int64_t& dim) {
auto output_shape = input_shape;
auto padding = output_shape.padding();
switch (dim) {
case 0:
case 1:
case 2:
case 3: output_shape[dim] = 1;
case 1: output_shape[dim] = 1;
break;
}

Expand Down Expand Up @@ -101,7 +94,7 @@ Tensor prod_(const Tensor& input, const int64_t& dim, const MemoryConfig& mem_co
return output;
}

Tensor prod(
Tensor prod_nc(
const Tensor& input,
const Tensor& output,
std::vector<int64_t>& dims,
Expand Down
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ struct Prod {
const auto attribute_values() const { return std::make_tuple(std::cref(this->dim)); }
};

operation::ProgramWithCallbacks prod_nc(const Tensor &input, const Tensor &output, int64_t dim);
operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor &output, int64_t dim);

Tensor prod_(
const Tensor &input,
std::optional<std::reference_wrapper<const Tensor>> output,
const int64_t &dim,
const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

Tensor prod(
Tensor prod_nc(
const Tensor &input,
const Tensor &output,
std::vector<int64_t> &dims,
Expand Down
4 changes: 2 additions & 2 deletions tt_eager/tt_lib/csrc/operations/primary/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,8 @@ void py_module(py::module& m_primary) {
"Performs sum operation. Returns an output tensor.");

m_primary.def(
"prod",
&prod,
"prod_nc",
&prod_nc,
py::arg("input").noconvert(),
py::arg("output").noconvert(),
py::kw_only(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ namespace tt::tt_metal::detail{
detail::bind_unary_op(m_tensor, "swish", swish, R"doc(Returns tensor with the swish all of elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "mish", &mish, R"doc(Returns tensor with the mish activation of elements of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "cbrt", &cbrt, R"doc(Returns tensor with the cbrt activation of elements of the input tensor ``{0}``.)doc");
// Docstring fix: the composite _prod reduces only dimension 0 (N), producing a
// {1, C, H, W}-shaped output — it does not reduce "along all dimensions".
detail::bind_unary_op(m_tensor, "prod", &prod, R"doc(Returns tensor with the product of elements computed along dimension 0 (N) of the input tensor ``{0}``.)doc");
detail::bind_unary_op(m_tensor, "asinh", &asinh, R"doc(Returns tensor with the inverse hyperbolic sine of elements of the input tensor ``{0}`` in range [-1e-6, 1e6].
for +input , output = asinh(input)
for -input , output = -asinh(input))doc"
Expand Down

0 comments on commit 817dc0a

Please sign in to comment.