Skip to content

Commit

Permalink
#0: Add PROD forward and backward support
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Apr 17, 2024
1 parent 4c91b0a commit 619fdb6
Show file tree
Hide file tree
Showing 41 changed files with 1,736 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,10 @@ Other Operations

.. autofunction:: tt_lib.tensor.xlogy

.. autofunction:: tt_lib.tensor.prod

.. autofunction:: tt_lib.tensor.tiled_prod

.. autofunction:: tt_lib.tensor.addcmul

.. autofunction:: tt_lib.tensor.addcdiv
Expand Down Expand Up @@ -830,6 +834,8 @@ Other Operations
Backward Operations
===================

.. autofunction:: tt_lib.tensor.prod_bw

.. autofunction:: tt_lib.tensor.addalpha_bw

.. autofunction:: tt_lib.tensor.addcmul_bw
Expand Down
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
"tt_lib_op": tt_lib_ops.arange,
"pytorch_op": pytorch_ops.arange,
},
"prod": {
"tt_lib_op": tt_lib_ops.prod,
"pytorch_op": pytorch_ops.prod,
},
# stats
"stats-var_hw": {
"tt_lib_op": tt_lib_ops.var_hw,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import random
from functools import partial
import tt_lib as ttl


from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)

mem_configs = [
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
]


@pytest.mark.parametrize(
"dim",
(3, 2, 1, 0, -1, -2, -3, -4),
)
@pytest.mark.parametrize("all_dimensions", [False, True])
@pytest.mark.parametrize(
"input_shapes",
[
[[1, 1, 32, 32]],
[[4, 3, 32, 32]],
[[2, 2, 32, 32]],
# [[6, 4, 32, 32]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
# [[1, 1, 320, 320]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
# [[1, 3, 320, 64]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan )
],
)
@pytest.mark.parametrize(
"dst_mem_config",
mem_configs,
)
class TestProd:
def test_run_prod_op(
self,
all_dimensions,
dim,
input_shapes,
dst_mem_config,
device,
):
datagen_func = [
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1.5), torch.bfloat16)
]
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
test_args.update(
{
"all_dimensions": all_dimensions,
"dim": dim,
}
)
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_pcc

run_single_pytorch_test(
"prod",
input_shapes,
datagen_func,
comparison_func,
device,
test_args,
)
7 changes: 7 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,13 @@ def xlogy(x, y, *args, **kwargs):
return torch.xlogy(x, y)


def prod(x, *args, all_dimensions, dim, **kwargs):
if all_dimensions:
result = torch.prod(x)
return result.view(1, 1, 1, 1)
return torch.prod(x, dim, keepdim=True)


def ldexp(x, y, *args, **kwargs):
return torch.ldexp(x, y)

Expand Down
22 changes: 22 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,28 @@ def arange(
return tt2torch_tensor(t1)


@setup_host_and_device
def prod(
x,
*args,
all_dimensions,
dim,
device,
dtype,
layout,
input_mem_config,
output_mem_config,
**kwargs,
):
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.prod(t0, all_dimensions, dim, output_mem_config=output_mem_config)
output = tt2torch_tensor(t1)
if all_dimensions:
return output[:1, :1, :1, :1]
else:
return output


@setup_host_and_device
def eltwise_logical_andi(
x,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import (
data_gen_pt_tt,
data_gen_pt_tt_prod,
compare_results,
)


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])), # 0
(torch.Size([1, 1, 320, 384])), # 1
(torch.Size([4, 2, 32, 32])), # 2
(torch.Size([1, 3, 320, 384])), # 3
(torch.Size([4, 3, 32, 32])), # 4
(torch.Size([4, 3, 64, 64])), # 5
(torch.Size([4, 3, 320, 320])), # 6
(torch.Size([4, 3, 32, 32])), # 7
(torch.Size([1, 3, 320, 320])), # 8
(torch.Size([1, 4, 320, 384])), # 9
(torch.Size([4, 4, 32, 32])), # 10
(torch.Size([5, 4, 32, 32])), # 11
(torch.Size([6, 4, 32, 32])), # 12
(torch.Size([4, 5, 32, 32])), # 13
(torch.Size([4, 6, 32, 32])), # 14
(torch.Size([4, 10, 32, 32])), # 15
(torch.Size([4, 20, 32, 32])), # 16
(torch.Size([4, 30, 32, 32])), # 17
(torch.Size([4, 31, 32, 32])), # 18
(torch.Size([4, 32, 32, 32])), # 19
(torch.Size([4, 33, 32, 32])), # 20
(torch.Size([4, 63, 32, 32])), # 21
(torch.Size([4, 64, 32, 32])), # 22
(torch.Size([32, 64, 32, 32])), # 23
),
)
@pytest.mark.parametrize(
"dim",
[-4, -3, -2, -1, 0, 1, 2, 3],
)
@pytest.mark.parametrize("all_dimensions", [True, False])
def test_bw_prod(input_shapes, all_dimensions, dim, device):
in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
grad_data, grad_tensor = data_gen_pt_tt_prod(input_shapes, device, all_dimensions, dim)
if all_dimensions == False:
pyt_y = torch.prod(in_data, dim=dim, keepdim=True)
else:
pyt_y = torch.prod(in_data).view(1, 1, 1, 1)
tt_output_tensor_on_device = tt_lib.tensor.prod_bw(grad_tensor, input_tensor, all_dimensions, dim)
in_data.retain_grad()
pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,39 @@ def data_gen_with_val(input_shapes, device, required_grad=False, val=1, is_row_m
return pt_tensor, tt_tensor


def data_gen_pt_tt_prod(input_shapes, device, all_dimensions, dim, required_grad=False):
torch.manual_seed(213919)
pt_tensor_temp = torch.zeros(input_shapes, requires_grad=required_grad).bfloat16()
shape_Required = torch.Size(
[
input_shapes[0] if (dim != 0 and dim != -4) else 1,
input_shapes[1] if (dim != 1 and dim != -3) else 1,
input_shapes[2] if (dim != 2 and dim != -2) else 1,
input_shapes[3] if (dim != 3 and dim != -1) else 1,
]
)
if all_dimensions == False and (dim == 1 or dim == 0 or dim == -4 or dim == -3):
pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16()
tt_tensor = (
tt_lib.tensor.Tensor(pt_tensor, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)
return pt_tensor, tt_tensor
elif all_dimensions == False:
pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16()
if dim == 3 or dim == -1:
pt_tensor_temp[:, :, :, :1] = pt_tensor
elif dim == 2 or dim == -2:
pt_tensor_temp[:, :, :1, :] = pt_tensor
else:
shape_Required = torch.Size([1, 1, 1, 1])
pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16()
pt_tensor_temp[:1, :1, :1, :1] = pt_tensor
tt_tensor = (
tt_lib.tensor.Tensor(pt_tensor_temp, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
)
return pt_tensor, tt_tensor


def compare_results(tt_tensor, golden_tensor, pcc=0.99):
status = True
for i in range(len(tt_tensor)):
Expand Down
78 changes: 78 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/test_prod_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from loguru import logger
from functools import partial

import tt_lib as ttl
from models.utility_functions import comp_allclose_and_pcc

from tests.tt_eager.python_api_testing.sweep_tests import (
comparison_funcs,
generation_funcs,
)
from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import (
run_single_pytorch_test,
)


def get_tensors(input_shape, output_shape, device):
torch.manual_seed(2023)
npu_dtype = ttl.tensor.DataType.BFLOAT16
cpu_dtype = torch.bfloat16
npu_layout = ttl.tensor.Layout.TILE

torch_input = torch.randint(1, 5, input_shape, dtype=cpu_dtype)
torch_output = torch.randint(1, 5, output_shape, dtype=cpu_dtype)
tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)

return tt_input, tt_output, torch_input


@pytest.mark.parametrize(
"shapes",
(
([1, 1, 32, 32]),
([1, 4, 32, 32]),
([2, 2, 32, 32]),
# ([6, 4, 32, 32]), #Fails : expected result is inf but the result generated in nan
# ([1, 1, 320, 320]), #Fails : expected result is inf but the result generated in nan
# ([1, 3, 320, 64]), #Fails : expected result is inf but the result generated in nan
),
)
def test_prod(shapes, device):
output_shape = shapes.copy()

(tt_input, tt_output, torch_input) = get_tensors(shapes, shapes, device)

torch_output = torch.prod(torch_input)

cpu_layout = ttl.tensor.Layout.ROW_MAJOR
tt_output_cpu = (
ttl.operations.primary.prod_all(tt_input).cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch()
)
N, C, H, W = tt_output_cpu.shape
torch.set_printoptions(threshold=10000, precision=5, sci_mode=False)
logger.info("Input shape")
logger.info(torch_input.shape)
logger.info("TT Output")
logger.info(tt_output_cpu[0, 0, 0, 0])
logger.info("Torch Output")
logger.info(torch_output)

# test for equivalance
# TODO(Dongjin) : check while changing rtol after enabling fp32_dest_acc_en
rtol = atol = 0.12
# passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol)
passing, output_pcc = comp_allclose_and_pcc(
torch_output, tt_output_cpu[0, 0, 0, 0], pcc=0.999, rtol=rtol, atol=atol
)

logger.info(f"Out passing={passing}")
logger.info(f"Output pcc={output_pcc}")

assert passing
Loading

0 comments on commit 619fdb6

Please sign in to comment.