Commit

#0: Add PROD support
VirdhatchaniKN committed Apr 12, 2024
1 parent 817dc0a commit e7ef646
Showing 45 changed files with 1,320 additions and 565 deletions.
4 changes: 4 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -785,6 +785,8 @@ Other Operations

.. autofunction:: tt_lib.tensor.prod

.. autofunction:: tt_lib.tensor.tiled_prod

.. autofunction:: tt_lib.tensor.addcmul

.. autofunction:: tt_lib.tensor.addcdiv
@@ -830,6 +832,8 @@ Other Operations
Backward Operations
===================

.. autofunction:: tt_lib.tensor.prod_bw

.. autofunction:: tt_lib.tensor.addalpha_bw

.. autofunction:: tt_lib.tensor.addcmul_bw
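For orientation, a minimal usage sketch of the new ops documented above. The argument order is taken from the call sites added later in this commit (tt_lib_ops.prod and test_bw_prod); the input tensor t0 and the upstream gradient grad_tensor are assumed to be tiled tensors already on the device (the sweep tests build them via setup_tt_tensor and data_gen_pt_tt_prod), and tiled_prod is documented above but not exercised in this diff.

    import tt_lib as ttl

    out_mem = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM)

    # Product over all dimensions vs. along a single dimension (last axis here).
    t_all = ttl.tensor.prod(t0, True, 0, output_mem_config=out_mem)
    t_dim = ttl.tensor.prod(t0, False, -1, output_mem_config=out_mem)

    # Backward pass: gradient of the product w.r.t. the input, given an upstream gradient.
    grads = ttl.tensor.prod_bw(grad_tensor, t0, False, -1)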
8 changes: 4 additions & 4 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
@@ -48,6 +48,10 @@
"tt_lib_op": tt_lib_ops.arange,
"pytorch_op": pytorch_ops.arange,
},
"prod": {
"tt_lib_op": tt_lib_ops.prod,
"pytorch_op": pytorch_ops.prod,
},
# stats
"stats-var_hw": {
"tt_lib_op": tt_lib_ops.var_hw,
@@ -590,10 +594,6 @@
"tt_lib_op": tt_lib_ops.eltwise_subalpha,
"pytorch_op": pytorch_ops.subalpha,
},
"eltwise-prod": {
"tt_lib_op": tt_lib_ops.eltwise_prod,
"pytorch_op": pytorch_ops.prod,
},
"eltwise-addalpha": {
"tt_lib_op": tt_lib_ops.eltwise_addalpha,
"pytorch_op": pytorch_ops.addalpha,
@@ -49,7 +49,7 @@ def custom_compare(*args, **kwargs):
return result


shapes = ([[1, 1, 32, 32]], [[4, 3, 32, 64]])
shapes = ([[1, 1, 32, 32]], [[1, 3, 320, 64]])
if is_wormhole_b0():
shapes = (shapes[0],)

@@ -60,59 +60,58 @@ def custom_compare(*args, **kwargs):
list(
product(
(
# "lerp_binary",
# "lerp_ternary",
# "addcmul",
# "addcdiv",
# "min",
# "max",
# "swish",
# "log1p",
# "softplus",
# "mish",
# "silu",
# "polyval",
# "mac",
# "cbrt",
# "threshold",
# "hypot",
# "hardswish",
# "hardsigmoid",
# "ones_like",
# "zeros_like",
# "full_like",
# "ones",
# "empty",
# "zeros",
# "full",
# "arange",
# "hardshrink",
# "softshrink",
# "sinh",
# "cosh",
# "tanhshrink",
# "xlogy",
# "asinh",
# "acosh",
# "atanh",
# "atan2",
# "subalpha",
# "bias_gelu_unary",
# "addalpha",
# "logit",
# "logical_ori",
# "logical_xor",
# "logical_xori",
# "logical_noti",
# "logical_andi",
# "isclose",
# "digamma",
# "lgamma",
# "multigammaln",
# "polygamma",
# "nextafter",
# "scatter",
"prod",
"lerp_binary",
"lerp_ternary",
"addcmul",
"addcdiv",
"min",
"max",
"swish",
"log1p",
"softplus",
"mish",
"silu",
"polyval",
"mac",
"cbrt",
"threshold",
"hypot",
"hardswish",
"hardsigmoid",
"ones_like",
"zeros_like",
"full_like",
"ones",
"empty",
"zeros",
"full",
"arange",
"hardshrink",
"softshrink",
"sinh",
"cosh",
"tanhshrink",
"xlogy",
"asinh",
"acosh",
"atanh",
"atan2",
"subalpha",
"bias_gelu_unary",
"addalpha",
"logit",
"logical_ori",
"logical_xor",
"logical_xori",
"logical_noti",
"logical_andi",
"isclose",
"digamma",
"lgamma",
"multigammaln",
"polygamma",
"nextafter",
"scatter",
),
shapes,
)
@@ -129,7 +128,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
options["hypot"] = (1, 100)
options["atan2"] = (-100, 100)
options["cbrt"] = (-1000, 1000)
options["prod"] = (1, 1.5)
options["hardsigmoid"] = (-100, 100)
options["hardswish"] = (-100, 100)
options["hardshrink"] = (-100, 100)
@@ -164,13 +162,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_def
torch.int32,
)
]
elif fn in ["prod"]: # "prod_cpu" not implemented for 'BFloat16'
datagen_func = [
generation_funcs.gen_func_with_cast(
partial(generator, low=options[fn][0], high=options[fn][1]),
torch.float32,
)
]
else:
datagen_func = [
generation_funcs.gen_func_with_cast(
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import random
from functools import partial
import tt_lib as ttl


from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)

mem_configs = [
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
"dim",
(3, 2, 1, 0, -1, -2, -3, -4),
)
@pytest.mark.parametrize("all_dimensions", [False, True])
@pytest.mark.parametrize(
"input_shapes",
[
[[1, 1, 32, 32]],
[[4, 3, 32, 32]],
[[2, 2, 32, 32]],
# [[6, 4, 32, 32]], # Fails for all_dimensions = True (expected result is inf but the generated result is nan)
# [[1, 1, 320, 320]], # Fails for all_dimensions = True (expected result is inf but the generated result is nan)
# [[1, 3, 320, 64]], # Fails for all_dimensions = True (expected result is inf but the generated result is nan)
],
)
@pytest.mark.parametrize(
"dst_mem_config",
mem_configs,
)
class TestProd:
def test_run_prod_op(
self,
all_dimensions,
dim,
input_shapes,
dst_mem_config,
device,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1.5), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update(
{
"all_dimensions": all_dimensions,
"dim": dim,
}
)
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_pcc

run_single_pytorch_test(
"prod",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
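A rough back-of-the-envelope estimate (my own, not part of the commit) of why the expected result for the commented-out larger shapes above is inf: with inputs drawn from (1, 1.5) the mean of ln(x) is about 0.216, so a full-tensor product of n elements has magnitude around exp(0.216 * n), while float32 overflows near exp(88.7).

    # Approximate log-magnitude of prod(x) for x ~ U(1, 1.5): about 0.216 per element.
    for shape, n in [("1x1x32x32", 32 * 32), ("1x3x320x64", 3 * 320 * 64)]:
        print(shape, "log(expected product) ~=", round(0.216 * n), "(float32 max is ~exp(88.7))")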
7 changes: 5 additions & 2 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -863,8 +863,11 @@ def xlogy(x, y, *args, **kwargs):
return torch.xlogy(x, y)


def prod(x, *args, **kwargs):
return torch.prod(x, 0)
def prod(x, *args, all_dimensions, dim, **kwargs):
if all_dimensions:
result = torch.prod(x)
return result.view(1, 1, 1, 1)
return torch.prod(x, dim, keepdim=True)


def ldexp(x, y, *args, **kwargs):
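A quick standalone check of the reference semantics above (plain PyTorch, not part of the diff): the all_dimensions branch collapses the whole tensor into a 1x1x1x1 product, while the dim branch reduces a single axis and keeps it.

    import torch

    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).view(1, 1, 2, 2)

    torch.prod(x).view(1, 1, 1, 1)     # tensor([[[[24.]]]])         -> all_dimensions=True path
    torch.prod(x, -1, keepdim=True)    # tensor([[[[ 2.], [12.]]]])  -> dim=-1 path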
23 changes: 22 additions & 1 deletion tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -1301,6 +1301,28 @@ def arange(
return tt2torch_tensor(t1)


@setup_host_and_device
def prod(
x,
*args,
all_dimensions,
dim,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.prod(t0, all_dimensions, dim, output_mem_config=output_mem_config)
output = tt2torch_tensor(t1)
if all_dimensions:
return output[:1, :1, :1, :1]
else:
return output


@setup_host_and_device
def eltwise_logical_andi(
x,
@@ -2262,7 +2284,6 @@ def binary_op(
eltwise_isneginf = make_unary_op(ttl.tensor.isneginf)
eltwise_isnan = make_unary_op(ttl.tensor.isnan)
eltwise_logical_not_unary = make_unary_op(ttl.tensor.logical_not_unary)
eltwise_prod = make_unary_op(ttl.tensor.prod)
eltwise_i0 = make_unary_op(ttl.tensor.i0)

################################################
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
data_gen_pt_tt,
data_gen_pt_tt_prod,
compare_results,
)


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])), # 0
(torch.Size([1, 1, 320, 384])), # 1
(torch.Size([4, 2, 32, 32])), # 2
(torch.Size([1, 3, 320, 384])), # 3
(torch.Size([4, 3, 32, 32])), # 4
(torch.Size([4, 3, 64, 64])), # 5
(torch.Size([4, 3, 320, 320])), # 6
(torch.Size([4, 3, 32, 32])), # 7
(torch.Size([1, 3, 320, 320])), # 8
(torch.Size([1, 4, 320, 384])), # 9
(torch.Size([4, 4, 32, 32])), # 10
(torch.Size([5, 4, 32, 32])), # 11
(torch.Size([6, 4, 32, 32])), # 12
(torch.Size([4, 5, 32, 32])), # 13
(torch.Size([4, 6, 32, 32])), # 14
(torch.Size([4, 10, 32, 32])), # 15
(torch.Size([4, 20, 32, 32])), # 16
(torch.Size([4, 30, 32, 32])), # 17
(torch.Size([4, 31, 32, 32])), # 18
(torch.Size([4, 32, 32, 32])), # 19
(torch.Size([4, 33, 32, 32])), # 20
(torch.Size([4, 63, 32, 32])), # 21
(torch.Size([4, 64, 32, 32])), # 22
(torch.Size([32, 64, 32, 32])), # 23
),
)
@pytest.mark.parametrize(
"dim",
[-4, -3, -2, -1, 0, 1, 2, 3],
)
@pytest.mark.parametrize("all_dimensions", [True, False])
def test_bw_prod(input_shapes, all_dimensions, dim, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt_prod(input_shapes, device, all_dimensions, dim)
if all_dimensions == False:
pyt_y = torch.prod(in_data, dim=dim, keepdim=True)
else:
pyt_y = torch.prod(in_data).view(1, 1, 1, 1)
tt_output_tensor_on_device = tt_lib.tensor.prod_bw(grad_tensor, input_tensor, all_dimensions, dim)
in_data.retain_grad()
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
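For reference, the golden path above reduces to the usual product rule for the full-tensor case, grad_i = upstream_grad * prod(x) / x_i (valid when no element is zero); a small PyTorch sketch of that identity, my own illustration rather than code from this commit:

    import torch

    x = torch.tensor([[[[2.0, 3.0], [4.0, 5.0]]]], requires_grad=True)
    y = torch.prod(x).view(1, 1, 1, 1)
    y.backward(gradient=torch.ones(1, 1, 1, 1))

    manual = torch.prod(x.detach()) / x.detach()   # product rule, elementwise
    print(torch.allclose(x.grad, manual))          # True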