From 3f4a69f4fe9ab23bccc255892e384ae92f55fb4d Mon Sep 17 00:00:00 2001 From: VirdhatchaniKN Date: Mon, 19 Feb 2024 12:37:09 +0000 Subject: [PATCH] #0: Add PROD forward and backward support --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 6 + .../python_api_testing/sweep_tests/op_map.py | 4 + .../sweep_tests/pytests/tt_dnn/test_prod.py | 75 +++++++ .../sweep_tests/pytorch_ops.py | 7 + .../sweep_tests/tt_lib_ops.py | 22 ++ .../backward_ops/test_backward_prod.py | 64 ++++++ .../backward_ops/utility_funcs.py | 33 +++ .../unit_testing/test_prod_all.py | 78 +++++++ .../unit_testing/test_prod_nc.py | 90 ++++++++ tt_eager/tt_dnn/module.mk | 4 + .../op_library/backward/backward_ops.cpp | 107 ++++++++++ .../op_library/backward/backward_ops.hpp | 2 + .../op_library/composite/composite_ops.cpp | 78 +++++++ .../op_library/composite/composite_ops.hpp | 8 + .../eltwise_unary/eltwise_unary_op.cpp | 2 + .../eltwise_unary/eltwise_unary_op.hpp | 4 +- .../prod/kernels/compute/prod_all.cpp | 71 +++++++ .../prod/kernels/compute/prod_nc.cpp | 58 +++++ .../prod/kernels/dataflow/reader_prod_nc.cpp | 65 ++++++ .../prod/kernels/dataflow/utils.hpp | 16 ++ .../op_library/prod/prod_nc/prod_nc.cpp | 201 ++++++++++++++++++ .../tt_dnn/op_library/prod/prod_nc_op.cpp | 121 +++++++++++ .../tt_dnn/op_library/prod/prod_nc_op.hpp | 53 +++++ .../tt_dnn/op_library/prod/prod_op_all.cpp | 60 ++++++ .../tt_dnn/op_library/prod/prod_op_all.hpp | 45 ++++ .../single_core/prod_op_all_single_core.cpp | 137 ++++++++++++ .../tt_lib/csrc/operations/primary/module.hpp | 20 ++ .../tt_lib_bindings_tensor_backward_ops.cpp | 19 ++ .../tt_lib_bindings_tensor_composite_ops.cpp | 17 ++ .../csrc/tt_lib_bindings_tensor_xary_ops.cpp | 1 + tt_eager/tt_numpy/functions.hpp | 114 ++++++++++ tt_metal/common/bfloat16.hpp | 3 + .../metal/llk_api/llk_math_unary_sfpu_api.h | 1 + .../llk_sfpu/ckernel_sfpu_tiled_prod.h | 34 +++ .../llk_math_eltwise_unary_sfpu_tiled_prod.h | 28 +++ .../grayskull/metal/llk_api/llk_sfpu_types.h | 1 + .../metal/llk_api/llk_math_unary_sfpu_api.h | 1 + .../llk_sfpu/ckernel_sfpu_tiled_prod.h | 34 +++ .../llk_math_eltwise_unary_sfpu_tiled_prod.h | 28 +++ .../metal/llk_api/llk_sfpu_types.h | 1 + tt_metal/include/compute_kernel_api.h | 24 +++ 41 files changed, 1736 insertions(+), 1 deletion(-) create mode 100644 tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_prod.py create mode 100644 tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_prod.py create mode 100644 tests/tt_eager/python_api_testing/unit_testing/test_prod_all.py create mode 100644 tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/kernels/dataflow/utils.hpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp create mode 100644 tt_eager/tt_dnn/op_library/prod/prod_op_all.hpp create mode 100644 tt_eager/tt_dnn/op_library/prod/single_core/prod_op_all_single_core.cpp create mode 100644 tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h create mode 100644 tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index 99fc4f551d25..400320ef20a5 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -785,6 +785,10 @@ Other Operations .. autofunction:: tt_lib.tensor.xlogy +.. autofunction:: tt_lib.tensor.prod + +.. autofunction:: tt_lib.tensor.tiled_prod + .. autofunction:: tt_lib.tensor.addcmul .. autofunction:: tt_lib.tensor.addcdiv @@ -830,6 +834,8 @@ Other Operations Backward Operations =================== +.. autofunction:: tt_lib.tensor.prod_bw + .. autofunction:: tt_lib.tensor.addalpha_bw .. autofunction:: tt_lib.tensor.addcmul_bw diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index e0d39c14d236..2859394e6079 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -48,6 +48,10 @@ "tt_lib_op": tt_lib_ops.arange, "pytorch_op": pytorch_ops.arange, }, + "prod": { + "tt_lib_op": tt_lib_ops.prod, + "pytorch_op": pytorch_ops.prod, + }, # stats "stats-var_hw": { "tt_lib_op": tt_lib_ops.var_hw, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_prod.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_prod.py new file mode 100644 index 000000000000..8ba79b93102e --- /dev/null +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_prod.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: © 2023-24 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import random +from functools import partial +import tt_lib as ttl + + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) + +mem_configs = [ + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM), + ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), +] + + +@pytest.mark.parametrize( + "dim", + (3, 2, 1, 0, -1, -2, -3, -4), +) +@pytest.mark.parametrize("all_dimensions", [False, True]) +@pytest.mark.parametrize( + "input_shapes", + [ + [[1, 1, 32, 32]], + [[4, 3, 32, 32]], + [[2, 2, 32, 32]], + # [[6, 4, 32, 32]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan ) + # [[1, 1, 320, 320]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan ) + # [[1, 3, 320, 64]], #Fails for all_dimensions = True ( expected result is inf but the result generated in nan ) + ], +) +@pytest.mark.parametrize( + "dst_mem_config", + mem_configs, +) +class TestProd: + def test_run_prod_op( + self, + all_dimensions, + dim, + input_shapes, + dst_mem_config, + device, + ): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1.5), torch.bfloat16) + ] + test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] + test_args.update( + { + "all_dimensions": all_dimensions, + "dim": dim, + } + ) + test_args.update({"output_mem_config": dst_mem_config}) + comparison_func = comparison_funcs.comp_pcc + + run_single_pytorch_test( + "prod", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 4234dfa799cd..f37a065243ea 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -863,6 +863,13 @@ def xlogy(x, y, *args, **kwargs): return torch.xlogy(x, y) +def prod(x, *args, all_dimensions, dim, **kwargs): + if all_dimensions: + result = torch.prod(x) + return result.view(1, 1, 1, 1) + return torch.prod(x, dim, keepdim=True) + + def ldexp(x, y, *args, **kwargs): return torch.ldexp(x, y) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index e67804c63724..6af1b3e0b895 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1301,6 +1301,28 @@ def arange( return tt2torch_tensor(t1) +@setup_host_and_device +def prod( + x, + *args, + all_dimensions, + dim, + device, + dtype, + layout, + input_mem_config, + output_mem_config, + **kwargs, +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = ttl.tensor.prod(t0, all_dimensions, dim, output_mem_config=output_mem_config) + output = tt2torch_tensor(t1) + if all_dimensions: + return output[:1, :1, :1, :1] + else: + return output + + @setup_host_and_device def eltwise_logical_andi( x, diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_prod.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_prod.py new file mode 100644 index 000000000000..d66971dacb97 --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/test_backward_prod.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import tt_lib +from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import ( + data_gen_pt_tt, + data_gen_pt_tt_prod, + compare_results, +) + + +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([1, 1, 32, 32])), # 0 + (torch.Size([1, 1, 320, 384])), # 1 + (torch.Size([4, 2, 32, 32])), # 2 + (torch.Size([1, 3, 320, 384])), # 3 + (torch.Size([4, 3, 32, 32])), # 4 + (torch.Size([4, 3, 64, 64])), # 5 + (torch.Size([4, 3, 320, 320])), # 6 + (torch.Size([4, 3, 32, 32])), # 7 + (torch.Size([1, 3, 320, 320])), # 8 + (torch.Size([1, 4, 320, 384])), # 9 + (torch.Size([4, 4, 32, 32])), # 10 + (torch.Size([5, 4, 32, 32])), # 11 + (torch.Size([6, 4, 32, 32])), # 12 + (torch.Size([4, 5, 32, 32])), # 13 + (torch.Size([4, 6, 32, 32])), # 14 + (torch.Size([4, 10, 32, 32])), # 15 + (torch.Size([4, 20, 32, 32])), # 16 + (torch.Size([4, 30, 32, 32])), # 17 + (torch.Size([4, 31, 32, 32])), # 18 + (torch.Size([4, 32, 32, 32])), # 19 + (torch.Size([4, 33, 32, 32])), # 20 + (torch.Size([4, 63, 32, 32])), # 21 + (torch.Size([4, 64, 32, 32])), # 22 + (torch.Size([32, 64, 32, 32])), # 23 + ), +) +@pytest.mark.parametrize( + "dim", + [-4, -3, -2, -1, 0, 1, 2, 3], +) +@pytest.mark.parametrize("all_dimensions", [True, False]) +def test_bw_prod(input_shapes, all_dimensions, dim, device): + in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True) + grad_data, grad_tensor = data_gen_pt_tt_prod(input_shapes, device, all_dimensions, dim) + if all_dimensions == False: + pyt_y = torch.prod(in_data, dim=dim, keepdim=True) + else: + pyt_y = torch.prod(in_data).view(1, 1, 1, 1) + tt_output_tensor_on_device = tt_lib.tensor.prod_bw(grad_tensor, input_tensor, all_dimensions, dim) + in_data.retain_grad() + pyt_y.backward(gradient=grad_data) + + golden_tensor = [in_data.grad] + + comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor) + + assert comp_pass diff --git a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/utility_funcs.py b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/utility_funcs.py index aba1ed576298..b83d435d32db 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/backward_ops/utility_funcs.py +++ b/tests/tt_eager/python_api_testing/unit_testing/backward_ops/utility_funcs.py @@ -52,6 +52,39 @@ def data_gen_with_val(input_shapes, device, required_grad=False, val=1, is_row_m return pt_tensor, tt_tensor +def data_gen_pt_tt_prod(input_shapes, device, all_dimensions, dim, required_grad=False): + torch.manual_seed(213919) + pt_tensor_temp = torch.zeros(input_shapes, requires_grad=required_grad).bfloat16() + shape_Required = torch.Size( + [ + input_shapes[0] if (dim != 0 and dim != -4) else 1, + input_shapes[1] if (dim != 1 and dim != -3) else 1, + input_shapes[2] if (dim != 2 and dim != -2) else 1, + input_shapes[3] if (dim != 3 and dim != -1) else 1, + ] + ) + if all_dimensions == False and (dim == 1 or dim == 0 or dim == -4 or dim == -3): + pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16() + tt_tensor = ( + tt_lib.tensor.Tensor(pt_tensor, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device) + ) + return pt_tensor, tt_tensor + elif all_dimensions == False: + pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16() + if dim == 3 or dim == -1: + pt_tensor_temp[:, :, :, :1] = pt_tensor + elif dim == 2 or dim == -2: + pt_tensor_temp[:, :, :1, :] = pt_tensor + else: + shape_Required = torch.Size([1, 1, 1, 1]) + pt_tensor = torch.randn(shape_Required, requires_grad=required_grad).bfloat16() + pt_tensor_temp[:1, :1, :1, :1] = pt_tensor + tt_tensor = ( + tt_lib.tensor.Tensor(pt_tensor_temp, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device) + ) + return pt_tensor, tt_tensor + + def compare_results(tt_tensor, golden_tensor, pcc=0.99): status = True for i in range(len(tt_tensor)): diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod_all.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod_all.py new file mode 100644 index 000000000000..2dd609a58328 --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod_all.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from loguru import logger +from functools import partial + +import tt_lib as ttl +from models.utility_functions import comp_allclose_and_pcc + +from tests.tt_eager.python_api_testing.sweep_tests import ( + comparison_funcs, + generation_funcs, +) +from tests.tt_eager.python_api_testing.sweep_tests.run_pytorch_ci_tests import ( + run_single_pytorch_test, +) + + +def get_tensors(input_shape, output_shape, device): + torch.manual_seed(2023) + npu_dtype = ttl.tensor.DataType.BFLOAT16 + cpu_dtype = torch.bfloat16 + npu_layout = ttl.tensor.Layout.TILE + + torch_input = torch.randint(1, 5, input_shape, dtype=cpu_dtype) + torch_output = torch.randint(1, 5, output_shape, dtype=cpu_dtype) + tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + + return tt_input, tt_output, torch_input + + +@pytest.mark.parametrize( + "shapes", + ( + ([1, 1, 32, 32]), + ([1, 4, 32, 32]), + ([2, 2, 32, 32]), + # ([6, 4, 32, 32]), #Fails : expected result is inf but the result generated in nan + # ([1, 1, 320, 320]), #Fails : expected result is inf but the result generated in nan + # ([1, 3, 320, 64]), #Fails : expected result is inf but the result generated in nan + ), +) +def test_prod(shapes, device): + output_shape = shapes.copy() + + (tt_input, tt_output, torch_input) = get_tensors(shapes, shapes, device) + + torch_output = torch.prod(torch_input) + + cpu_layout = ttl.tensor.Layout.ROW_MAJOR + tt_output_cpu = ( + ttl.operations.primary.prod_all(tt_input).cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + ) + N, C, H, W = tt_output_cpu.shape + torch.set_printoptions(threshold=10000, precision=5, sci_mode=False) + logger.info("Input shape") + logger.info(torch_input.shape) + logger.info("TT Output") + logger.info(tt_output_cpu[0, 0, 0, 0]) + logger.info("Torch Output") + logger.info(torch_output) + + # test for equivalance + # TODO(Dongjin) : check while changing rtol after enabling fp32_dest_acc_en + rtol = atol = 0.12 + # passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol) + passing, output_pcc = comp_allclose_and_pcc( + torch_output, tt_output_cpu[0, 0, 0, 0], pcc=0.999, rtol=rtol, atol=atol + ) + + logger.info(f"Out passing={passing}") + logger.info(f"Output pcc={output_pcc}") + + assert passing diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py b/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py new file mode 100644 index 000000000000..c761082e37a6 --- /dev/null +++ b/tests/tt_eager/python_api_testing/unit_testing/test_prod_nc.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from loguru import logger + +import tt_lib as ttl +from models.utility_functions import comp_allclose_and_pcc + +TILE_HEIGHT = 32 +TILE_WIDTH = 32 + + +def get_tensors(input_shape, output_shape, device): + torch.manual_seed(2023) + npu_dtype = ttl.tensor.DataType.BFLOAT16 + cpu_dtype = torch.bfloat16 + npu_layout = ttl.tensor.Layout.TILE + + torch_input = torch.randint(-100, 100, input_shape, dtype=cpu_dtype) + torch_output = torch.randint(-100, 100, output_shape, dtype=cpu_dtype) + + tt_input = ttl.tensor.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + tt_output = ttl.tensor.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + + return tt_input, tt_output, torch_input + + +@pytest.mark.parametrize( + "input_shape", + ( + ([2, 3, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 7 - 1]), + ([9, 16, TILE_HEIGHT * 13 - 1, TILE_WIDTH * 19 - 1]), + ([4, 3, TILE_HEIGHT * 3 - 1, TILE_WIDTH * 11 - 1]), + ([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]), + ([4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1]), + ([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1]), + ([8, 8, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1]), + ), + ids=[ + "2, 3, TILE_HEIGHT * 6 - 1, TILE_WIDTH * 7 - 1", + "9, 16, TILE_HEIGHT * 13 - 1, TILE_WIDTH * 19 - 1", + "4, 3, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 11 - 1", + "1, 1, TILE_HEIGHT-1,TILE_WIDTH - 1", + "4, 4, TILE_HEIGHT * 9 - 1, TILE_WIDTH * 12 - 1", + "4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 9 - 1", + "8, 8, TILE_HEIGHT * 4 - 1, TILE_WIDTH * 4 - 1", + ], +) +@pytest.mark.parametrize( + "dims", + ( + [ + 0, + ], + [ + 1, + ], + ), + ids=["0", "1"], +) +# Support for dim 2,3 in composite_ops +def test_prod_dims(input_shape, dims, device): + output_shape = input_shape.copy() + + for dim in dims: + output_shape[dim] = 1 + + (tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device) + + torch_output = torch.prod(torch_input, dims[0], True) + + cpu_layout = ttl.tensor.Layout.ROW_MAJOR + tt_output_cpu = ( + ttl.operations.primary.prod_nc(tt_input, tt_output, dims=dims) + .cpu() + .to(cpu_layout) + .unpad_from_tile(output_shape) + .to_torch() + ) + + rtol = atol = 0.1 + passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol) + + logger.info(f"Out passing={passing}") + logger.info(f"Output pcc={output_pcc}") + + assert passing diff --git a/tt_eager/tt_dnn/module.mk b/tt_eager/tt_dnn/module.mk index e6a79eaf00aa..1a3ea1c100b5 100644 --- a/tt_eager/tt_dnn/module.mk +++ b/tt_eager/tt_dnn/module.mk @@ -108,6 +108,10 @@ TT_DNN_SRCS = \ tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp \ tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp \ tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp \ + tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp \ + tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp \ + tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp \ + tt_eager/tt_dnn/op_library/prod/single_core/prod_op_all_single_core.cpp \ tt_eager/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.cpp \ tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_h/moreh_mean_h.cpp \ tt_eager/tt_dnn/op_library/moreh_mean/moreh_mean_w/moreh_mean_w.cpp \ diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp index 8461ae75e0cf..f1c2192756ea 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp @@ -13,6 +13,8 @@ #include "tt_dnn/op_library/math.hpp" #include "tt_dnn/op_library/unpad/unpad_op.hpp" #include "tt_dnn/op_library/complex/complex_ops.hpp" +#include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp" +#include "tt_dnn/op_library/permute/permute_op.hpp" namespace tt { @@ -1517,6 +1519,111 @@ std::vector binary_gt_bw(const Tensor& grad, const Tensor& input, const return operation::decorate_as_composite(__func__, _binary_gt_bw)(grad, input, output_mem_config); } +// Prod +// along a single dimension --> result: grad_data * (y / input ) +std::vector _prod_bw( + const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) { + std::vector grad_tensor; + Tensor prod_result = prod(input, all_dimensions, dim, output_mem_config); + if (all_dimensions == true) { + Tensor temp = mul(prod_result, grad, std::nullopt, output_mem_config); // result is stored in the first position + Tensor fill_tensor = tt::numpy::fill_first_val_into_tensor( temp, temp.get_dtype(), temp.get_layout(), temp.device(), output_mem_config); + Tensor all_dimension_result = mul(recip(input, output_mem_config), fill_tensor, std::nullopt, output_mem_config); + grad_tensor.emplace_back(all_dimension_result); + return grad_tensor; + } + // all_dimensions = False + Tensor updated_grad = prod_result; + if (prod_result.get_legacy_shape() != grad.get_legacy_shape()) { + if (dim == 3 || dim == -1) { + std::vector after_permute_dims = {0, 3, 1, 2}; + Tensor required = permute(grad, after_permute_dims, output_mem_config); + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = { grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[2] - 1}; + Tensor new_unpad_tensor = unpad(required, start_index, end_index); + after_permute_dims = {0, 2, 3, 1}; + updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config); + } else if (dim == 2 || dim == -2) { + std::vector after_permute_dims = {0, 2, 1, 3}; + Tensor required = permute(grad, after_permute_dims, output_mem_config); + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = { grad.get_legacy_shape()[0] - 1, 0, grad.get_legacy_shape()[1] - 1, grad.get_legacy_shape()[3] - 1}; + Tensor new_unpad_tensor = unpad(required, start_index, end_index); + updated_grad = permute(new_unpad_tensor, after_permute_dims, output_mem_config); + } + } + Tensor reciprocal_input = recip(input, output_mem_config); + Tensor temp = mul(prod_result, (dim == 1 || dim == 0 || dim == -4 || dim == -3) ? grad : updated_grad, std::nullopt, output_mem_config); + if (dim == 3 || dim == -1) { + Tensor grad_result = bcast(reciprocal_input, temp, BcastOpMath::MUL, BcastOpDim::W, output_mem_config); + grad_tensor.emplace_back(grad_result); + return grad_tensor; + } else if (dim == 2 || dim == -2) { + Tensor grad_result = bcast(reciprocal_input, temp, BcastOpMath::MUL, BcastOpDim::H, output_mem_config); + grad_tensor.emplace_back(grad_result); + return grad_tensor; + } else if (dim == 1 || dim == -3) { + Tensor tensor_1_temp = reciprocal_input; + if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) { + const Shape start_index = {0, 0, 0, 0}; + const Shape required_shape = { + reciprocal_input.get_legacy_shape()[0], + reciprocal_input.get_legacy_shape()[1] + (32 - (reciprocal_input.get_legacy_shape()[1] % 32)), + reciprocal_input.get_legacy_shape()[2], + reciprocal_input.get_legacy_shape()[3]}; + tensor_1_temp = pad(reciprocal_input, required_shape, start_index, 0); + } + std::vector after_permute_dims = {0, 2, 3, 1}; + Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config); + Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config); + after_permute_dims = {0, 3, 1, 2}; + Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), after_permute_dims, output_mem_config); + Tensor grad_result = result; + if (reciprocal_input.get_legacy_shape()[1] % 32 != 0) { + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = { + input.get_legacy_shape()[0] - 1, + input.get_legacy_shape()[1] - 1, + input.get_legacy_shape()[2] - 1, + input.get_legacy_shape()[3] - 1}; + grad_result = unpad(result, start_index, end_index); + } + grad_tensor.emplace_back(grad_result); + return grad_tensor; + } + // dim 0 + Tensor tensor_1_temp = reciprocal_input; + if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) { + const Shape start_index = {0, 0, 0, 0}; + const Shape required_shape = { + reciprocal_input.get_legacy_shape()[0] + (32 - (reciprocal_input.get_legacy_shape()[0] % 32)), + reciprocal_input.get_legacy_shape()[1], + reciprocal_input.get_legacy_shape()[2], + reciprocal_input.get_legacy_shape()[3]}; + tensor_1_temp = pad(reciprocal_input, required_shape, start_index, 0); + } + std::vector after_permute_dims = {3, 1, 2, 0}; + Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config); + Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config); + Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), after_permute_dims, output_mem_config); + Tensor grad_result = result; + if (reciprocal_input.get_legacy_shape()[0] % 32 != 0) { + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = { + input.get_legacy_shape()[0] - 1, + input.get_legacy_shape()[1] - 1, + input.get_legacy_shape()[2] - 1, + input.get_legacy_shape()[3] - 1}; + grad_result = unpad(result, start_index, end_index); + } + grad_tensor.emplace_back(grad_result); + return grad_tensor; +} +std::vector prod_bw( + const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) { + return operation::decorate_as_composite(__func__, _prod_bw)(grad, input, all_dimensions, dim, output_mem_config); +} + // square // result: 2 * input * grad_data std::vector _square_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config) { diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp index a67d32cf05cd..cb1607296c05 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.hpp @@ -220,6 +220,8 @@ std::vector frac_bw(const Tensor& grad, const Tensor& input, const Memor std::vector trunc_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +std::vector prod_bw(const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + std::vector log_sigmoid_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); std::vector tanhshrink_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index c87a6b1f2e62..9925b9916bfd 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -16,6 +16,10 @@ #include "tt_eager/tensor/tensor_utils.hpp" #include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp" #include "tt_numpy/functions.hpp" +#include "tt_dnn/op_library/prod/prod_nc_op.hpp" +#include "tt_dnn/op_library/prod/prod_op_all.hpp" +#include "tt_dnn/op_library/permute/permute_op.hpp" +#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp" namespace tt { namespace tt_metal { @@ -901,6 +905,80 @@ Tensor xlogy(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& o return operation::decorate_as_composite(__func__, _xlogy)(input_a, input_b, output_mem_config); } +Tensor prod_all(const Tensor& input_a, const MemoryConfig& output_mem_config) { + auto formatted_input_tensor = input_a; + if(formatted_input_tensor.get_layout()==Layout::ROW_MAJOR){ + auto a_pad_shape = AutoFormat::pad_to_tile_shape(input_a.get_legacy_shape(), false, false, true, true); + auto out_shape = input_a.get_legacy_shape(); + out_shape = {out_shape[0], out_shape[1], out_shape[2], out_shape[3]}; + if (!AutoFormat::check_input_tensor_format(input_a, a_pad_shape)) { + formatted_input_tensor = AutoFormat::format_input_tensor(input_a, input_a.device(), a_pad_shape, 1.0, Layout::TILE); + } + } + return tt::operations::primary::prod_all(formatted_input_tensor, output_mem_config); +} + +Tensor prod_nc(const Tensor& temp, int64_t dim, const MemoryConfig& output_mem_config) { + //layout conversion + auto formatted_input_tensor = temp; + if(formatted_input_tensor.get_layout()==Layout::ROW_MAJOR){ + auto a_pad_shape = AutoFormat::pad_to_tile_shape(temp.get_legacy_shape(), false, false, true, true); + auto out_shape = temp.get_legacy_shape(); + out_shape = {out_shape[0], out_shape[1], out_shape[2], out_shape[3]}; + if (!AutoFormat::check_input_tensor_format(temp, a_pad_shape)) { + formatted_input_tensor = AutoFormat::format_input_tensor(temp, temp.device(), a_pad_shape, 1.0, Layout::TILE); + } + } + //Apply prod + std::vector dimension = {(dim == 1 || dim == -3) ? 1 : 0}; + Shape input_shape = formatted_input_tensor.get_legacy_shape(); + Shape required = { ((dim == 1 || dim == -3) ? input_shape[0] : 1), ((dim == 1 || dim == -3) ? 1 : input_shape[1]) , input_shape[2], input_shape[3]}; + return tt::operations::primary::prod_nc(formatted_input_tensor, zeros( required, formatted_input_tensor.get_dtype(), formatted_input_tensor.get_layout(), formatted_input_tensor.device(), output_mem_config), dimension, output_mem_config); +} + +Tensor _prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) { + if(all_dimensions){ + return tt::tt_metal::prod_all(input_a, output_mem_config); + } + TT_FATAL(dim >= -4 && dim <= 3 && "Dimension out of range (expected to be in range of [-4, 3]"); + Tensor temp = input_a; + //Permute for dim 2,3 + if(dim == 2 || dim == -2){ + std::vector permute_dims = {2, 0, 1, 3}; + temp = permute(input_a, permute_dims, output_mem_config); + }else if(dim == 3 || dim == -1){ + std::vector permute_dims = {3, 0, 1, 2}; + temp = permute(input_a, permute_dims, output_mem_config); + } + Tensor result = tt::tt_metal::prod_nc(temp, dim, output_mem_config); + //Permute and unpad result for dim 2,3 + if(dim == 0 || dim == 1 || dim == -4 || dim == -3){ + return result; + }else if(dim == 2 || dim == -2){ + std::vector after_permute_dims = {1, 2, 0, 3}; + Tensor required = permute(result, after_permute_dims, output_mem_config); + Shape input_shape = input_a.get_legacy_shape(); + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = {input_shape[0]-1, input_shape[1]-1, 0, input_shape[3]-1}; + return unpad( required, start_index, end_index); + }else{ //dim 3 + //permute + std::vector after_permute_dims = {1, 2, 0, 3}; + Tensor required = permute(result, after_permute_dims, output_mem_config); + //unpad + Shape input_shape = input_a.get_legacy_shape(); + const Shape start_index = {0, 0, 0, 0}; + const Shape end_index = {input_shape[0]-1, input_shape[1]-1, 0, input_shape[2]-1}; + Tensor new_unpad_tensor = unpad( required, start_index, end_index); + //permute back + after_permute_dims = {0, 1, 3, 2}; + return permute(new_unpad_tensor, after_permute_dims, output_mem_config); + } +} +Tensor prod(const Tensor& input_a, bool all_dimensions, int64_t dim, const MemoryConfig& output_mem_config) { + return operation::decorate_as_composite(__func__, _prod)(input_a, all_dimensions, dim, output_mem_config); +} + Tensor _variance_impl( const Tensor& y, const Tensor& mean_y, Tensor& y_minus_mean_y, const MemoryConfig& output_mem_config) { constexpr float correction = 0.0f; diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index ffd734044ccf..2845ed867155 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -178,6 +178,14 @@ Tensor logical_noti( float immediate, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +//prod +Tensor prod( + const Tensor& input_a, + bool all_dimensions = false, + int64_t dim = 0, + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + + /* Returns a new tensor with the signed angles in radians between vectors diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 2449384cd667..c5d1f17a9b7c 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -170,6 +170,8 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, stri op_init_and_name = {"sign_tile_init();", fmt::format("sign_tile({});", idst)}; break; case UnaryOpType::SQUARE: op_init_and_name = {"square_tile_init();", fmt::format("square_tile({});", idst)}; break; + case UnaryOpType::TILED_PROD: + op_init_and_name = {"tiled_prod_tile_init();", fmt::format("tiled_prod_tile({});", idst)}; break; case UnaryOpType::EQZ: op_init_and_name = {"eqz_tile_init();", fmt::format("eqz_tile({});", idst)}; break; case UnaryOpType::NEZ: diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index b0fd4689ece2..b8a7cedbcd81 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -74,7 +74,8 @@ enum class UnaryOpType { SUB_UNARY_SFPU = 56, MUL_UNARY_SFPU = 57, DIV_UNARY_SFPU = 58, - IDENTITY_UINT32 = 59 + IDENTITY_UINT32 = 59, + TILED_PROD = 60 }; template @@ -270,6 +271,7 @@ constexpr auto recip = make_eltwise_unary{}; constexpr auto relu = make_eltwise_unary{}; constexpr auto relu6 = make_eltwise_unary{}; constexpr auto sigmoid = make_eltwise_unary{}; +constexpr auto tiled_prod = make_eltwise_unary{}; constexpr auto log = make_eltwise_unary{}; constexpr auto tanh = make_eltwise_unary{}; constexpr auto log2 = make_eltwise_unary{}; diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp new file mode 100644 index 000000000000..9375737b75ed --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/common.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" +#include "tt_metal/include/compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" +#include "compute_kernel_api/eltwise_unary/negative.h" + +namespace NAMESPACE { +void MAIN { + + constexpr uint32_t num_tiles = get_compile_time_arg_val(0); + constexpr uint32_t per_core_block_dim = get_compile_time_arg_val(1); + + binary_op_init_common(tt::CB::c_in0, tt::CB::c_intermed0, tt::CB::c_out0); + bool last_tile = false; + bool once = true; + for (uint32_t t = 0; t < num_tiles; t++) { + if ( t == (num_tiles - 1)) + { + last_tile = true; + } + cb_reserve_back(tt::CB::c_out0, 1); + for(uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { + cb_wait_front(tt::CB::c_in0, 1); + if (once) + { + cb_reserve_back(tt::CB::c_intermed0, 1); + tile_regs_acquire(); + copy_tile_to_dst_init_short(); + copy_tile(tt::CB::c_in0, 0, 0); // copy from c_in[0] to DST[0] + tile_regs_commit(); + tile_regs_wait(); + if constexpr (num_tiles == 1) + pack_tile(0, tt::CB::c_out0); + else + { + pack_tile(0, tt::CB::c_intermed0); + cb_push_back(tt::CB::c_intermed0, 1); + } + tile_regs_release(); + }else { + tile_regs_acquire(); + mul_tiles_init(); + mul_tiles(tt::CB::c_in0, tt::CB::c_intermed0, 0, 0, 0); + tile_regs_commit(); + tile_regs_wait(); + if (last_tile) + { + pack_tile(0, tt::CB::c_out0); + } + else + { + cb_pop_front(tt::CB::c_intermed0, 1); + cb_reserve_back(tt::CB::c_intermed0, 1); + pack_tile(0, tt::CB::c_intermed0); + cb_push_back(tt::CB::c_intermed0, 1); + } + tile_regs_release(); + } + once = false; + cb_pop_front(tt::CB::c_in0, 1); + } + cb_push_back(tt::CB::c_out0, 1); +} +} +} diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp new file mode 100644 index 000000000000..a599e01631ac --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/eltwise_binary.h" +#include "compute_kernel_api/tile_move_copy.h" + +namespace NAMESPACE { +void MAIN { + const auto num_input_tiles = get_arg_val(0); + const auto num_output_tiles = get_arg_val(1); + + constexpr auto cb_in0 = tt::CB::c_in0; + constexpr auto cb_in1 = tt::CB::c_in1; + constexpr auto cb_out0 = tt::CB::c_out0; + constexpr auto cb_intermed0 = tt::CB::c_intermed0; + constexpr uint32_t onetile = 1; + constexpr uint32_t dst0 = 0; + constexpr uint32_t dst1 = 1; + constexpr uint32_t first_tile = 0; + + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in1); + cb_wait_front(cb_in1, onetile); + + for (uint32_t i = 0; i < num_output_tiles; i++) { + bool enable_reload = false; + for (uint32_t j = 0; j < num_input_tiles; ++j) { + bool last_out = (j == num_input_tiles - 1); + uint32_t cb_add = (enable_reload) ? (cb_intermed0) : (cb_in1); + + cb_wait_front(cb_in0, onetile); + if (enable_reload) { + cb_wait_front(cb_intermed0, onetile); + } + + tile_regs_acquire(); + mul_tiles_init(); + mul_tiles(cb_in0, cb_add, first_tile, first_tile, dst0); + tile_regs_commit(); + + cb_pop_front(cb_in0, onetile); + if (enable_reload) { + cb_pop_front(cb_intermed0, onetile); + } + + uint32_t cb_out = (last_out) ? (cb_out0) : (cb_intermed0); + cb_reserve_back(cb_out, onetile); + tile_regs_wait(); + pack_tile(dst0, cb_out); + tile_regs_release(); + cb_push_back(cb_out, onetile); + enable_reload = true; + } + } +} +} // namespace NAMESPACE diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp new file mode 100644 index 000000000000..3c1fdfaf0f63 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp @@ -0,0 +1,65 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" +#include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" + +void kernel_main() { + const auto input_addr = get_arg_val(0); + const auto num_input_tiles = get_arg_val(1); + const auto num_output_tiles = get_arg_val(2); + const auto input_tile_offset = get_arg_val(3); + const auto start_id = get_arg_val(4); + const auto input_is_dram = get_compile_time_arg_val(0) == 1; + const auto HtWt = get_arg_val(6); + const auto CHtWt = get_arg_val(7); + const auto dim = get_compile_time_arg_val(1); + + constexpr uint32_t onetile = 1; + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; + + union { + float f; + uint32_t u; + } scaler; + scaler.f = 1.0f; + fill_cb_with_value(cb_id_in1, scaler.u); + + uint32_t l1_write_addr_in0; + uint32_t input_tile_bytes = get_tile_size(cb_id_in0); + const auto input_data_format = get_dataformat(cb_id_in0); + const InterleavedAddrGenFast dram_input_addrg = { + .bank_base_address = input_addr, .page_size = input_tile_bytes, .data_format = input_data_format}; + + uint32_t read_tile_id_temp = (dim == 0 ) ? (start_id) : (start_id / HtWt * CHtWt) + (start_id % HtWt); + uint32_t start_tile_id = start_id / HtWt * CHtWt; + uint32_t end_tile_id = start_tile_id + HtWt - 1 ; + uint32_t read_tile_id = read_tile_id_temp; + for (uint32_t i = start_id; i < start_id + num_output_tiles; i++) { + if constexpr (dim == 0){ + read_tile_id = i; + } + for (uint32_t j = 0; j < num_input_tiles; ++j) { + cb_reserve_back(cb_id_in0, onetile); + l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(read_tile_id, dram_input_addrg, l1_write_addr_in0); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, onetile); + read_tile_id += input_tile_offset; + } + if constexpr (dim != 0){ + if(read_tile_id_temp == end_tile_id){ + start_tile_id = start_tile_id + CHtWt; + read_tile_id_temp = start_tile_id; + end_tile_id = read_tile_id_temp + HtWt - 1; + }else{ + read_tile_id_temp = read_tile_id_temp + 1; + } + read_tile_id = read_tile_id_temp; + } + } +} diff --git a/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/utils.hpp b/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/utils.hpp new file mode 100644 index 000000000000..b9c8ecc95b5f --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/kernels/dataflow/utils.hpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void fill_cb_with_value(uint32_t cb_id, uint32_t value, int32_t num_of_elems = 1024) { + cb_reserve_back(cb_id, 1); + volatile tt_l1_ptr std::uint16_t* ptr = (volatile tt_l1_ptr uint16_t*)(get_write_ptr(cb_id)); + for (int j = 0; j < num_of_elems; j++) { + ptr[j] = uint16_t(value >> 16); + } + cb_push_back(cb_id, 1); +} diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp new file mode 100644 index 000000000000..78781395077d --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc/prod_nc.cpp @@ -0,0 +1,201 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_eager/tt_dnn/op_library/work_split.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + +namespace tt { +using namespace constants; +namespace operations { + +namespace primary { + +operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor &output, int64_t dim) { + TT_ASSERT(dim == 0 || dim == 1); + + //////////////////////////////////////////////////////////////////////////// + // Device Setup + //////////////////////////////////////////////////////////////////////////// + auto *device = input.device(); + auto program = Program(); + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + const auto cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); + const auto single_tile_size = detail::TileSize(cb_data_format); + + const auto &input_shape = input.get_legacy_shape(); + const auto &input_shape_without_padding = input_shape.without_padding(); + + const auto N = input_shape[0]; + const auto C = input_shape[1]; + const auto Ht = input_shape[2] / TILE_HEIGHT; + const auto Wt = input_shape[3] / TILE_WIDTH; + const auto HtWt = Ht * Wt; + const auto CHtWt = C * Ht * Wt; + const auto num_reduce_input_tile = input_shape[dim]; + const auto input_tile_offset = (dim == 0) ? (CHtWt) : (HtWt); + const auto num_output_tiles = output.volume() / TILE_HW; + + log_debug(LogTest, "N {} C {} Ht {} Wt {}", N, C, Ht, Wt); + log_debug( + LogTest, + "dim {} num_reduce_input_tile {} input_tile_offset {}, num_output_tiles {}", + dim, + num_reduce_input_tile, + input_tile_offset, + num_output_tiles); + + //////////////////////////////////////////////////////////////////////////// + // Core Setup + //////////////////////////////////////////////////////////////////////////// + CoreGridDesc core_grid(device); + const auto num_cores_y = core_grid.y_; + CoreCoord core_grid_coord = {core_grid.x_, num_cores_y}; + + const uint32_t in0_t = 2; // input + const uint32_t in1_t = 1; // zero + const uint32_t intermed0_t = 1; // accumulated sum + const uint32_t out0_t = 2; // output + const auto + [num_cores_to_be_used, + all_cores, + core_group_1, + core_group_2, + num_cols_per_core_group_1, + num_cols_per_core_group_2] = tt_metal::split_work_to_cores(core_grid_coord, num_output_tiles); + + //////////////////////////////////////////////////////////////////////////// + // CircularBuffer Setup + //////////////////////////////////////////////////////////////////////////// + CreateCircularBuffer( + program, + all_cores, + cb_data_format, + { + {CB::c_in0, in0_t}, // input + {CB::c_in1, in1_t}, // zero + {CB::c_intermed0, intermed0_t}, // accumulated sum + {CB::c_out0, out0_t}, // output + }); + + //////////////////////////////////////////////////////////////////////////// + // DataMovementKernel SetUp + //////////////////////////////////////////////////////////////////////////// + + tt_metal::Buffer *input_buffer_type = input.buffer(); + bool input_is_dram = input_buffer_type->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(std::uint32_t) input_is_dram, static_cast(dim)}; + + tt_metal::Buffer *output_buffer_type = output.buffer(); + constexpr uint32_t cb_id_out = 16; + bool output_is_dram = output_buffer_type->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = {(std::uint32_t) cb_id_out, (std::uint32_t) output_is_dram}; + + const auto reader_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/dataflow/reader_prod_nc.cpp"; + const auto writer_kernel_file = "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp"; + const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args); + const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args); + + //////////////////////////////////////////////////////////////////////////// + // ComputeKernel SetUp + //////////////////////////////////////////////////////////////////////////// + const std::vector compute_args_group_1{num_cols_per_core_group_1}; + std::map compute_defines; + const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_nc.cpp"; + const auto compute_kernel_1_id = CreateComputeKernel( + program, compute_kernel_file, {core_group_1, num_cols_per_core_group_1, compute_args_group_1}, compute_defines); + + std::optional compute_kernel_2_id = std::nullopt; + if (!core_group_2.ranges().empty()) { + const std::vector compute_args_group_2{num_cols_per_core_group_2}; + compute_kernel_2_id = CreateComputeKernel( + program, + compute_kernel_file, + {core_group_2, num_cols_per_core_group_2, compute_args_group_2}, + compute_defines); + } + + //////////////////////////////////////////////////////////////////////////// + // RuntimeArgs SetUp + //////////////////////////////////////////////////////////////////////////// + for (uint32_t i = 0, tile_offset = 0; i < num_cores_to_be_used; ++i) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + + uint32_t num_tiles_per_core; + if (core_group_1.core_coord_in_core_ranges(core)) { + num_tiles_per_core = num_cols_per_core_group_1; + } else if (core_group_2.core_coord_in_core_ranges(core)) { + num_tiles_per_core = num_cols_per_core_group_2; + } else { + TT_THROW("Core not in specified core ranges."); + } + + SetRuntimeArgs( + program, + reader_kernel_id, + core, + {input.buffer()->address(), + num_reduce_input_tile, + num_tiles_per_core, + input_tile_offset, + tile_offset, + static_cast(is_dram(input)), + HtWt, + CHtWt, + static_cast(dim) + }); + + SetRuntimeArgs( + program, + writer_kernel_id, + core, + {output.buffer()->address(), num_tiles_per_core, tile_offset, static_cast(is_dram(output))}); + + if (core_group_1.core_coord_in_core_ranges(core)) { + SetRuntimeArgs(program, compute_kernel_1_id, core, {num_reduce_input_tile, num_tiles_per_core}); + } else if (core_group_2.core_coord_in_core_ranges(core)) { + TT_ASSERT(compute_kernel_2_id.has_value()); + SetRuntimeArgs(program, compute_kernel_2_id.value(), core, {num_reduce_input_tile, num_tiles_per_core}); + } else { + TT_ASSERT(false, "Core not in specified core ranges."); + } + tile_offset += num_tiles_per_core; + } + + auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, num_cores_to_be_used, num_cores_y]( + const void *operation, + const Program &program, + const std::vector &input_tensors, + const std::vector> &, + const std::vector &output_tensors) { + const auto *input_buffer = input_tensors.at(0).buffer(); + const auto *output_buffer = input_tensors.at(1).buffer(); + for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + { + auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + runtime_args[0] = input_buffer->address(); + SetRuntimeArgs(program, reader_kernel_id, core, runtime_args); + } + + { + auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + runtime_args[0] = output_buffer->address(); + SetRuntimeArgs(program, writer_kernel_id, core, runtime_args); + } + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp new file mode 100644 index 000000000000..fc8712975d06 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp @@ -0,0 +1,121 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/prod/prod_nc_op.hpp" +#include "tt_dnn/op_library/reduce/reduce_op.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" + +namespace tt { +using namespace constants; +namespace operations { +namespace primary { + +//////////////////////////////////////////////////////////////////////////// +// Prod +//////////////////////////////////////////////////////////////////////////// +void Prod::validate(const std::vector& inputs) const { + TT_FATAL((dim >= 0 && dim <= 3), "dim should be 0 - 3"); + const auto& input = inputs.at(0); + const auto& output = inputs.at(1); + + auto input_shape = input.get_legacy_shape(); + TT_FATAL((input_shape.rank() == 4), "rank should be 4"); + const auto& output_shape = output.get_legacy_shape(); + auto input_shape_wo_padding = input.get_legacy_shape().without_padding(); + const auto& output_shape_wo_padding = output.get_legacy_shape().without_padding(); + + if (dim == 0 || dim == 1) { + input_shape[dim] = 1; + input_shape_wo_padding[dim] = 1; + } + + for (int i = 0; i < input_shape.rank(); ++i) { + TT_FATAL(input_shape[i] == output_shape[i]); + // TT_FATAL(input_shape_wo_padding[i] == output_shape_wo_padding[i]); + } +} + +std::vector Prod::create_output_tensors(const std::vector& inputs) const { + // Inplace + return {}; +} + +std::vector Prod::compute_output_shapes(const std::vector& inputs) const { + // Inplace + return {}; + +} + +operation::ProgramWithCallbacks Prod::create_program( + const std::vector& inputs, std::vector& outputs) const { + auto& input = inputs.at(0); + auto& output = inputs.at(1); + + return prod_nc_format(input, output, dim); +} + +inline Shape compute_output_shape(const Shape& input_shape, const int64_t& dim) { + auto output_shape = input_shape; + auto padding = output_shape.padding(); + switch (dim) { + case 0: + case 1: output_shape[dim] = 1; + break; + } + + return {Shape(output_shape, padding)}; +} + +inline Tensor create_output_tensor( + const Tensor& input_tensor, const Shape& output_shape, const MemoryConfig& mem_config) { + TT_ASSERT(input_tensor.storage_type() == StorageType::DEVICE); + return create_device_tensor(output_shape, input_tensor.get_dtype(), Layout::TILE, input_tensor.device(), mem_config); +} + +// output as arg +Tensor prod_(const Tensor& input, const Tensor& output, const int64_t& dim) { + operation::run(Prod{.dim = dim}, {input, output}); + return output; +} + +// output creation inside +Tensor prod_(const Tensor& input, const int64_t& dim, const MemoryConfig& mem_config) { + const auto& input_shape = input.get_legacy_shape(); + const auto& output_shape = compute_output_shape(input_shape, dim); + auto output = create_output_tensor(input, output_shape, mem_config); + + const auto& output_shape_wo_padding = output.get_legacy_shape().without_padding(); + operation::run(Prod{.dim = dim}, {input, output}); + return output; +} + +Tensor prod_nc( + const Tensor& input, + const Tensor& output, + std::vector& dims, + const MemoryConfig& output_mem_config) { + // reduce for all dims + if (dims.empty()) { + dims = {0, 1, 2, 3}; + } + + std::vector sorted_dims = dims; + std::sort(sorted_dims.begin(), sorted_dims.end()); + + auto temp_input = input; + for (uint32_t i = dims.size() - 1; i > 0; i--) { + log_debug(LogTest, "{}:{} dim {}", __func__, __LINE__, sorted_dims[i]); + auto temp_output = prod_(temp_input, sorted_dims[i], output_mem_config); + temp_input = temp_output; + } + log_debug(LogTest, "{}:{} dim {}", __func__, __LINE__, sorted_dims.front()); + prod_(temp_input, output, sorted_dims.front()); + return output; +} + +} // namespace primary +} // namespace operations +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp new file mode 100644 index 000000000000..3d879c3071b4 --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_nc_op.hpp @@ -0,0 +1,53 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "tt_dnn/op_library/run_operation.hpp" +#include "tt_eager/tensor/tensor.hpp" + +namespace tt { + +namespace operations { + +namespace primary { + +using namespace tt_metal; + +struct Prod { + int64_t dim; + void validate(const std::vector &inputs) const; + std::vector compute_output_shapes(const std::vector &inputs) const; + std::vector create_output_tensors(const std::vector &inputs) const; + operation::ProgramWithCallbacks create_program( + const std::vector &inputs, std::vector &outputs) const; + stl::reflection::Attributes attributes() const; + static constexpr auto attribute_names = std::make_tuple("dim"); + const auto attribute_values() const { return std::make_tuple(std::cref(this->dim)); } +}; + +operation::ProgramWithCallbacks prod_nc_format(const Tensor &input, const Tensor &output, int64_t dim); + +Tensor prod_( + const Tensor &input, + std::optional> output, + const int64_t &dim, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +Tensor prod_nc( + const Tensor &input, + const Tensor &output, + std::vector &dims, + const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + +} // namespace primary + +} // namespace operations + +} // namespace tt diff --git a/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp b/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp new file mode 100644 index 000000000000..385321f1431f --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_op_all.cpp @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/prod/prod_op_all.hpp" +#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp" + +#include +#include + +#include "tt_metal/common/constants.hpp" +#include +#include "tt_metal/host_api.hpp" +#include "tt_metal/tools/profiler/op_profiler.hpp" + +namespace tt { +using namespace constants; +namespace operations { +namespace primary { + +void Prod_op::validate(const std::vector& input_tensors) const { + const auto& input_tensor_a = input_tensors.at(0); + TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands need to be on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr , "Operands need to be allocated in buffers on device!"); + TT_FATAL((input_tensor_a.get_layout() == Layout::TILE), "Input Layout must be tilized"); + TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(input_tensor_a.get_dtype() == DataType::BFLOAT16); +} + +std::vector Prod_op::compute_output_shapes(const std::vector& input_tensors) const { + const auto& input_tensor = input_tensors.at(0); + return {input_tensor.get_legacy_shape()}; +} + +std::vector Prod_op::create_output_tensors(const std::vector& input_tensors) const { +const auto& input_tensor = input_tensors.at(0); + return operation::generic_create_output_tensors(*this, input_tensors, input_tensor.get_dtype(), Layout::TILE, this->output_mem_config); +} + +operation::ProgramWithCallbacks Prod_op::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + const auto& input_tensor_a = input_tensors.at(0); + auto& output_tensor = output_tensors.at(0); + return prod_single_core(input_tensor_a, output_tensor); +} + +Tensor prod_all(const Tensor& input, const MemoryConfig& output_mem_config ) { + Tensor result = tiled_prod( operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0), output_mem_config); + auto arch_env = detect_arch(); + if(arch_env == tt::ARCH::WORMHOLE_B0){ + return tt::numpy::prod_result_computation_WH_B0(result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config); + } + //else --> GS Arch + return tt::numpy::prod_result_computation_GS(result, result.get_dtype(), result.get_layout(), result.device(), output_mem_config); + return operation::run(Prod_op{.output_mem_config = output_mem_config}, {input}).at(0); +} + +} +} +} diff --git a/tt_eager/tt_dnn/op_library/prod/prod_op_all.hpp b/tt_eager/tt_dnn/op_library/prod/prod_op_all.hpp new file mode 100644 index 000000000000..3fa1c497174f --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/prod_op_all.hpp @@ -0,0 +1,45 @@ +/* + * SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once +#include "tensor/tensor.hpp" +#include "tt_dnn/op_library/run_operation.hpp" + +namespace tt { + +namespace operations{ + +namespace primary{ + +/* + * prod product + */ + +struct Prod_op { + const MemoryConfig output_mem_config; + const DataType output_dtype; // TODO: Uplift output_dtype as an option for general dot/bmm + void validate(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector &input_tensors, std::vector &output_tensors) const; + static constexpr auto attribute_names = + std::make_tuple("output_mem_config", "output_dtype"); + const auto attribute_values() const { + return std::make_tuple( + std::cref(this->output_mem_config), std::cref(this->output_dtype)); + } +}; + +operation::ProgramWithCallbacks prod_single_core(const Tensor &input_tensor_a, const Tensor &output_tensor); + +Tensor prod_all( + const Tensor &input, + const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); +} + +} +} diff --git a/tt_eager/tt_dnn/op_library/prod/single_core/prod_op_all_single_core.cpp b/tt_eager/tt_dnn/op_library/prod/single_core/prod_op_all_single_core.cpp new file mode 100644 index 000000000000..86bb6cbbe81a --- /dev/null +++ b/tt_eager/tt_dnn/op_library/prod/single_core/prod_op_all_single_core.cpp @@ -0,0 +1,137 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "tt_dnn/op_library/prod/prod_op_all.hpp" +#include "tt_metal/common/constants.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" + + +namespace tt { +using namespace constants; +namespace operations { +namespace primary { + + operation::ProgramWithCallbacks prod_single_core(const Tensor &a, const Tensor& output) + { + + Program program{}; + + CoreRange core({0, 0}, {0, 0}); + + tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); + uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format); + + uint32_t num_tiles = a.volume() / TILE_HW; + + // This should allocate a DRAM buffer on the device + tt_metal::Device *device = a.device(); + + uint32_t src0_cb_index = 0; + uint32_t num_input_tiles = 2; + tt_metal::CircularBufferConfig cb_src0_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, single_tile_size); + auto cb_src0 = tt_metal::CreateCircularBuffer(program, core, cb_src0_config); + + tt_metal::CircularBufferConfig cb_inter_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{tt::CB::c_intermed0, cb_data_format}}) + .set_page_size(tt::CB::c_intermed0, single_tile_size); + auto cb_interm = tt_metal::CreateCircularBuffer(program, core, cb_inter_config); + + uint32_t output_cb_index = 16; // output operands start at index 16 + uint32_t num_output_tiles = 2; + tt_metal::CircularBufferConfig cb_output_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}}) + .set_page_size(output_cb_index, single_tile_size); + auto cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); + + auto src_buffer = a.buffer(); + auto dst_buffer = output.buffer(); + + bool src_is_dram = src_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector reader_compile_time_args = {(uint32_t)src_is_dram}; + bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; + std::vector writer_compile_time_args = { + (std::uint32_t) output_cb_index, + (std::uint32_t) dst_is_dram + }; + + tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/reader_unary_interleaved_start_id.cpp", + core, + tt_metal::ReaderDataMovementConfig{reader_compile_time_args}); + + tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/kernels/dataflow/writer_unary_interleaved_start_id.cpp", + core, + tt_metal::WriterDataMovementConfig{writer_compile_time_args}); + + vector compute_kernel_args = { + num_tiles, // per_core_block_cnt + 1 // per_core_block_size + }; + + bool fp32_dest_acc_en = false; + bool math_approx_mode = true; + auto eltwise_unary_kernel_id = tt_metal::CreateKernel( + program, + "tt_eager/tt_dnn/op_library/prod/kernels/compute/prod_all.cpp", + core, + tt_metal::ComputeConfig{ + .math_fidelity = MathFidelity::HiFi4, + .fp32_dest_acc_en = fp32_dest_acc_en, + .math_approx_mode = math_approx_mode, + .compile_args = compute_kernel_args + } + ); + + SetRuntimeArgs( + program, + unary_reader_kernel_id, + core, + { + src_buffer->address(), + num_tiles, 0 + } + ); + + SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + { + dst_buffer->address(), + num_tiles, 0 + } + ); + + auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id]( + const Program &program, + const std::vector& input_buffers, + const std::vector& output_buffers + ) { + + auto src_buffer = input_buffers.at(0); + + auto dst_buffer = output_buffers.at(0); + + CoreCoord core = {0, 0}; + + { + auto &runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = src_buffer->address(); + } + + { + auto &runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args[0] = dst_buffer->address(); + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + +} +} +} diff --git a/tt_eager/tt_lib/csrc/operations/primary/module.hpp b/tt_eager/tt_lib/csrc/operations/primary/module.hpp index cc3deee5bfd0..4258bac3494c 100644 --- a/tt_eager/tt_lib/csrc/operations/primary/module.hpp +++ b/tt_eager/tt_lib/csrc/operations/primary/module.hpp @@ -28,6 +28,7 @@ #include "tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" #include "tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" #include "tt_dnn/op_library/softmax/softmax_op.hpp" +#include "tt_dnn/op_library/prod/prod_nc_op.hpp" #include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" #include "tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.hpp" #include "tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" @@ -39,6 +40,7 @@ #include "tt_dnn/op_library/moreh_mean/moreh_mean_op.hpp" #include "tt_dnn/op_library/moreh_mean_backward/moreh_mean_backward_op.hpp" #include "tt_dnn/op_library/moreh_getitem/moreh_getitem_op.hpp" +#include "tt_dnn/op_library/prod/prod_op_all.hpp" namespace py = pybind11; @@ -479,6 +481,14 @@ void py_module(py::module& m_primary) { Performs a rmsnorm(a+b)*gamma + beta operation. )doc"); + //prod along all dimensions + m_primary.def( + "prod_all", + &prod_all, + py::arg("input").noconvert(), + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + "Computes prod along along dimensions of the tensor."); + // moreh_adam m_primary.def( "moreh_adam", @@ -837,6 +847,16 @@ void py_module(py::module& m_primary) { py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, "Performs sum operation. Returns an output tensor."); + m_primary.def( + "prod_nc", + &prod_nc, + py::arg("input").noconvert(), + py::arg("output").noconvert(), + py::kw_only(), + py::arg("dims").noconvert() = std::vector(), + py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + "Performs product operation. Returns an output tensor."); + m_primary.def( "moreh_sum_backward", &moreh_sum_backward, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp index bde7cf041550..3288535cb62c 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp @@ -1699,6 +1699,25 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); + m_tensor.def("prod_bw", &tt::tt_metal::prod_bw, + py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("all_dimensions") , py::arg("dim") , py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + Performs backward operations for prod on ``input_a`` along ``all_dimensions`` or a particular ``dim``. + If ``all_dimensions`` is set to ``true``, irrespective of given dimension it will perform backward prod for all dimensions. + + Input tensor must have BFLOAT16 data type. + + Output tensors will have BFLOAT16 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "input", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "all_dimensions", "Consider all dimension (ignores ``dim`` param)", "bool", "", "Yes" + "dim", "Dimension to perform prod", "int", "", "Yes" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + )doc"); + m_tensor.def("log_sigmoid_bw", &tt::tt_metal::log_sigmoid_bw, py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( Performs backward operations for log_sigmoid ``input`` tensors with given ``grad``. diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index 78d3551bd3ab..f16bf681ae84 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -207,6 +207,23 @@ namespace tt::tt_metal::detail{ R"doc(Applies the Gated Linear Units (GLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc", R"doc(dimension to split)doc" ); + m_tensor.def("prod", &prod, + py::arg("input").noconvert(), py::arg("all_dimensions") = false, py::arg("dim") = 0, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + Computes the prod function along specified ``dim`` or all dimensions on the ``input`` tensor. + If ``all_dimensions`` is set to ``true`` irrespective of given dimension it will prod along all dimensions. + + Input tensor must have BFLOAT16 data type. + + Output tensor will have BFLOAT16 data type. + + .. csv-table:: + :header: "Argument", "Description", "Data type", "Valid range", "Required" + + "input", "Tensor prod is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" + "all_dimensions", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No" + "dim", "Dimension to perform prod", "int", "default to 0", "Yes" + "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + )doc"); detail::bind_unary_op_with_param( m_tensor, "geglu", &geglu, py::arg("dim") = -1, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index b508bbe8286a..42a78f305417 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -59,6 +59,7 @@ namespace tt::tt_metal::detail { detail::bind_unary_op(m_tensor, "isfinite", isfinite, R"doc(Returns boolean tensor that is True where input tensor ``{0}``, is finite and False elsewhere.)doc"); detail::bind_unary_op(m_tensor, "isinf", isinf, R"doc(Returns boolean tensor that is True where input tensor ``{0}``, is infinite and False elsewhere.)doc"); detail::bind_unary_op(m_tensor, "isposinf", isposinf, R"doc(Returns each element of input tensor ``{0}``, is positive infinity or not.)doc"); + detail::bind_unary_op(m_tensor, "tiled_prod", tiled_prod, R"doc(Performs tile-wise multiplication on input tensor ``{0}`` and store the result in the last tile of the input tensor.)doc"); detail::bind_unary_op(m_tensor, "isneginf", isneginf, R"doc(Returns each element of input tensor ``{0}``, is negative infinity or not.)doc"); detail::bind_unary_op(m_tensor, "isnan", isnan, R"doc(Returns boolean tensor that is True where tensor ``{0}``, is NaN and False elsewhere.)doc"); detail::bind_unary_op(m_tensor, "sign", sign, R"doc(Returns tensor with the elementwise signum of the input tensor ``{0}``.)doc"); diff --git a/tt_eager/tt_numpy/functions.hpp b/tt_eager/tt_numpy/functions.hpp index 66c30ab4a7f6..3e0dd5fc0679 100644 --- a/tt_eager/tt_numpy/functions.hpp +++ b/tt_eager/tt_numpy/functions.hpp @@ -354,6 +354,120 @@ static Tensor index_all( return output; } +template +static Tensor fill_first_val_into_tensor(const Tensor& input_tensor, DataType data_type, + const Layout layout , Device * device = nullptr, + const MemoryConfig& output_mem_config = MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { + const Shape& s_a = input_tensor.get_legacy_shape(); + auto owned_buffer = tt_metal::owned_buffer::create(tt_metal::compute_volume(s_a)); //ouput + auto device_buffer = input_tensor.device_buffer(); + uint32_t size_in_bytes = device_buffer->size(); + vector data_vec; + const char *TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); + if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { + data_vec.resize(size_in_bytes / sizeof(T)); + tt::tt_metal::tensor_impl::read_data_from_device_buffer(input_tensor.device()->command_queue(), device_buffer, data_vec.data(), true); + } else { + tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); + } + auto input_buffer = owned_buffer::create(std::move(data_vec)); + const Shape input_tensor_strides = input_tensor.strides(); + for(uint32_t i = 0; i < tt_metal::compute_volume(s_a); i++) { + owned_buffer[i] = input_buffer[0]; + } + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + if (device != nullptr) { + output = output.to(device, output_mem_config); + } + return output; +} + +template +static Tensor prod_result_computation_GS(const Tensor& input_tensor, DataType data_type, + const Layout layout , Device * device = nullptr, + const MemoryConfig& output_mem_config = MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { + const Shape& s_a = input_tensor.get_legacy_shape(); + auto owned_buffer = tt_metal::owned_buffer::create(tt_metal::compute_volume(s_a)); //ouput + auto device_buffer = input_tensor.device_buffer(); + uint32_t size_in_bytes = device_buffer->size(); + vector data_vec; + const char *TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); + if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { + data_vec.resize(size_in_bytes / sizeof(T)); + tt::tt_metal::tensor_impl::read_data_from_device_buffer(input_tensor.device()->command_queue(), device_buffer, data_vec.data(), true); + } else { + tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); + } + auto input_buffer = owned_buffer::create(std::move(data_vec)); + const Shape input_tensor_strides = input_tensor.strides(); + auto result = static_cast(1.0f); + for(uint32_t i = s_a[0]-1; i < s_a[0]; i++) { + for(int32_t j = s_a[1]-1; j < s_a[1]; j++) { + for(int32_t k = s_a[2]-32; k < s_a[2]; k++) { //access last tile + for(int32_t l = s_a[3]-32; l < s_a[3]; l++) { + auto input_index = l + input_tensor_strides[2] * k + input_tensor_strides[1] * j + input_tensor_strides[0] * i; + if(k>=s_a[2]-2 && l>=s_a[3]-32){ //to access 2*32 in TILE layout + result = result * static_cast(input_buffer[input_index]); + owned_buffer[input_index] = static_cast(0.0f); + }else{ + owned_buffer[input_index] = static_cast(0.0f); + } + } + } + } + } + owned_buffer[0] = result; //store the result at the first position of the tensor,and the rest of the values as 0.0f + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + if (device != nullptr) { + output = output.to(device, output_mem_config); + } + return output; +} + +template +static Tensor prod_result_computation_WH_B0(const Tensor& input_tensor, DataType data_type, + const Layout layout , Device * device = nullptr, + const MemoryConfig& output_mem_config = MemoryConfig{.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) { + const Shape& s_a = input_tensor.get_legacy_shape(); + auto owned_buffer = tt_metal::owned_buffer::create(tt_metal::compute_volume(s_a)); //ouput + auto device_buffer = input_tensor.device_buffer(); + uint32_t size_in_bytes = device_buffer->size(); + vector data_vec; + const char *TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); + if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { + data_vec.resize(size_in_bytes / sizeof(T)); + tt::tt_metal::tensor_impl::read_data_from_device_buffer(input_tensor.device()->command_queue(), device_buffer, data_vec.data(), true); + } else { + tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); + } + auto input_buffer = owned_buffer::create(std::move(data_vec)); + const Shape input_tensor_strides = input_tensor.strides(); + auto result = static_cast(1.0f); + // need to access the last 4 rows and alternating columns of index 17 ,19, 21, 23, 25, 27, 29, 31 + for(uint32_t i = s_a[0]-1; i < s_a[0]; i++) { + for(int32_t j = s_a[1]-1; j < s_a[1]; j++) { + for(int32_t k = s_a[2]-32; k < s_a[2]; k++) { //access last tile + for(int32_t l = s_a[3]-32; l < s_a[3]; l++) { + auto input_index = l + input_tensor_strides[2] * k + input_tensor_strides[1] * j + input_tensor_strides[0] * i; + if(k>=s_a[2]-4 && (l==s_a[3]-15 || l==s_a[3]-13 || l==s_a[3]-11 || l==s_a[3]-9 || l==s_a[3]-7 || l==s_a[3]-5 || l==s_a[3]-3 || l==s_a[3]-1)){ //to access 4*16 elements placed alternatively starting from index 17W in TILE layout + result = result * static_cast(input_buffer[input_index]); + owned_buffer[input_index] = static_cast(0.0f); + }else{ + owned_buffer[input_index] = static_cast(0.0f); + } + } + } + } + } + owned_buffer[0] = result; //store the result at the first position of the tensor,and the rest of the values as 0.0f + auto output = Tensor(OwnedStorage{owned_buffer}, s_a, data_type, layout).to(layout); + if (device != nullptr) { + output = output.to(device, output_mem_config); + } + return output; +} + + template static Tensor index_channel( const Shape& shape, diff --git a/tt_metal/common/bfloat16.hpp b/tt_metal/common/bfloat16.hpp index f6dde378b1c8..a6c58ef16256 100644 --- a/tt_metal/common/bfloat16.hpp +++ b/tt_metal/common/bfloat16.hpp @@ -70,6 +70,9 @@ class bfloat16 { bool operator!=(const bfloat16 rhs) const { return not (*this == rhs); } + bfloat16 operator*(const bfloat16 rhs) const { + return bfloat16(this->to_float() * rhs.to_float()); + } }; inline ostream& operator<<(ostream& os, const bfloat16& bfp16) diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h index 8162c89d5e72..501b7a307a23 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h @@ -11,6 +11,7 @@ #include "llk_math_eltwise_unary_sfpu_heaviside.h" #include "llk_math_eltwise_unary_sfpu_power.h" #include "llk_math_eltwise_unary_sfpu_rsqrt.h" +#include "llk_math_eltwise_unary_sfpu_tiled_prod.h" #include "llk_math_eltwise_unary_sfpu_sigmoid.h" #include "llk_math_eltwise_unary_sfpu_sign.h" #include "llk_math_eltwise_unary_sfpu_signbit.h" diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h new file mode 100644 index 000000000000..de91583c898b --- /dev/null +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "noc_nonblocking_api.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_tiled_prod() +{ + vFloat result = 1.0f; + #pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + vFloat v = dst_reg[0]; + result *= v; + dst_reg[0] = result; + dst_reg++; + } + vFloat v = dst_reg[0]; + result *= v; + dst_reg[0] = result; + dst_reg++; +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h new file mode 100644 index 000000000000..929a71a8815c --- /dev/null +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_unary_sfpu_init.h" +#include "llk_math_eltwise_unary_sfpu_0_param.h" +#include "ckernel_sfpu_tiled_prod.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_tiled_prod_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_tiled_prod(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_0_param + (ckernel::sfpu::calculate_tiled_prod, + ckernel::sfpu::calculate_tiled_prod, + dst_index, vector_mode); +} + +} diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h index bf23a084b6df..5e7a3526ae2c 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu_types.h @@ -60,5 +60,6 @@ enum SfpuType { silu, mask, negative, + tiled_prod, unused, }; diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h index 23972fb69f9e..74f6d1a0a310 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h @@ -14,6 +14,7 @@ #include "llk_math_eltwise_unary_sfpu_max.h" #include "llk_math_eltwise_unary_sfpu_power.h" #include "llk_math_eltwise_unary_sfpu_rsqrt.h" +#include "llk_math_eltwise_unary_sfpu_tiled_prod.h" #include "llk_math_eltwise_unary_sfpu_sigmoid.h" #include "llk_math_eltwise_unary_sfpu_sign.h" #include "llk_math_eltwise_unary_sfpu_signbit.h" diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h new file mode 100644 index 000000000000..2e1fb1201c5e --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_tiled_prod.h @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "noc_nonblocking_api.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_tiled_prod() +{ + vFloat result = 1.0f; + #pragma GCC unroll 8 + for (int d = 0; d < ITERATIONS; d++) { + vFloat v = dst_reg[0]; + result *= v; + dst_reg[0] = result; + dst_reg++; + } + vFloat v = dst_reg[0]; + result *= v; + dst_reg[0] = result; + dst_reg++; +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h new file mode 100644 index 000000000000..3891d688bdaa --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_tiled_prod.h @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_unary_sfpu_init.h" +#include "llk_math_eltwise_unary_sfpu_0_param.h" +#include "ckernel_sfpu_tiled_prod.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_tiled_prod_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_tiled_prod(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_0_param + (ckernel::sfpu::calculate_tiled_prod, + ckernel::sfpu::calculate_tiled_prod, + dst_index, vector_mode); +} + +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index e2105c90b676..83f94fa716df 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -70,5 +70,6 @@ enum SfpuType { topk_local_sort, topk_merge, topk_rebuild, + tiled_prod, unused, }; diff --git a/tt_metal/include/compute_kernel_api.h b/tt_metal/include/compute_kernel_api.h index b1ea83c175cc..d2d3196d0847 100644 --- a/tt_metal/include/compute_kernel_api.h +++ b/tt_metal/include/compute_kernel_api.h @@ -357,6 +357,30 @@ ALWI void lez_tile_init() { MATH(( llk_math_eltwise_unary_sfpu_lez_init() )); } +/** + * Performs element-wise multiplication on each row of a tile. + * The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True | + */ +ALWI void tiled_prod_tile(uint32_t idst) { + MATH(( llk_math_eltwise_unary_sfpu_tiled_prod(idst) )); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void tiled_prod_tile_init() { + MATH(( llk_math_eltwise_unary_sfpu_tiled_prod_init() )); +} + + /** * Will store in the output of the compute core True if each element is greater than zero. * The DST register buffer must be in acquired state via *acquire_dst* call.