diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_comp_init.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_comp_init.py new file mode 100644 index 00000000000..294a6e892b1 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_comp_init.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( + data_gen_with_range, + data_gen_with_range_dtype, +) +from models.utility_functions import is_grayskull, skip_for_blackhole + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([64, 64])), + (torch.Size([2, 32, 32])), + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize( + "mem_configs", + ( + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), + ), +) +@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16)) +@pytest.mark.parametrize( + "ttnn_function", + (ttnn.lt, ttnn.gt, ttnn.eq, ttnn.le, ttnn.ge, ttnn.ne, ttnn.logical_and, ttnn.logical_or, ttnn.logical_xor), +) +def test_binary_comp_ops(input_shapes, out_dtype, mem_configs, ttnn_function, device): + if is_grayskull(): + pytest.skip("GS does not support fp32/uint32/uint16 data types") + + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + + cq_id = 0 + mem_cfg = mem_configs + + tt_output_tensor_on_device = ttnn_function( + input_tensor, other_tensor, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id + ) + + golden_fn = ttnn.get_golden_function(ttnn_function) + golden_tensor = golden_fn(in_data, other_data) + 
golden_tensor = golden_tensor.int() + + output_tensor = ttnn.to_torch(tt_output_tensor_on_device) + + are_equal = torch.equal(output_tensor, golden_tensor) + assert are_equal + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([64, 64])), + (torch.Size([2, 32, 32])), + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize( + "mem_configs", + ( + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), + ), +) +@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16)) +@pytest.mark.parametrize( + "ttnn_function", + (ttnn.lt, ttnn.gt, ttnn.eq, ttnn.le, ttnn.ge, ttnn.ne, ttnn.logical_and, ttnn.logical_or, ttnn.logical_xor), +) +def test_binary_comp_opt_out(input_shapes, out_dtype, mem_configs, ttnn_function, device): + if is_grayskull(): + pytest.skip("GS does not support fp32/uint32/uint16 data types") + + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) + + cq_id = 0 + mem_cfg = mem_configs + _, output_tensor = data_gen_with_range_dtype(input_shapes, -1, 1, device, False, False, out_dtype) + ttnn_function( + input_tensor, other_tensor, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id, output_tensor=output_tensor + ) + + golden_fn = ttnn.get_golden_function(ttnn_function) + golden_tensor = golden_fn(in_data, other_data) + golden_tensor = golden_tensor.int() + + output_tensor = ttnn.to_torch(output_tensor) + + are_equal = torch.equal(output_tensor, golden_tensor) + assert are_equal + + +@skip_for_blackhole("Mismatching on BH, see #12349") +@pytest.mark.parametrize( + "input_shapes", + ( + (torch.Size([64, 64])), + (torch.Size([2, 32, 32])), + (torch.Size([1, 1, 32, 32])), + (torch.Size([1, 1, 
320, 384])), + (torch.Size([1, 3, 320, 384])), + ), +) +@pytest.mark.parametrize( + "mem_configs", + ( + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), + ), +) +@pytest.mark.parametrize( + "scalar", + {2.3, 15.6, 55.4, 72.5, 120.6}, +) +@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16)) +@pytest.mark.parametrize( + "ttnn_function", + ( + ttnn.lt, + ttnn.gt, + ttnn.eq, + ttnn.le, + ttnn.ge, + ttnn.ne, + ), +) +def test_binary_comp_ops_scalar(input_shapes, scalar, out_dtype, mem_configs, ttnn_function, device): + if is_grayskull(): + pytest.skip("GS does not support fp32/uint32/uint16 data types") + + in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) + + cq_id = 0 + mem_cfg = mem_configs + + tt_output_tensor_on_device = ttnn_function( + input_tensor, scalar, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id + ) + + golden_fn = ttnn.get_golden_function(ttnn_function) + golden_tensor = golden_fn(in_data, scalar) + golden_tensor = golden_tensor.int() + + output_tensor = ttnn.to_torch(tt_output_tensor_on_device) + + are_equal = torch.equal(output_tensor, golden_tensor) + assert are_equal diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp index ff7aa1738bb..11a1013c731 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp @@ -10,6 +10,7 @@ #include "ttnn/operations/data_movement/repeat/repeat.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/operations/copy.hpp" namespace ttnn::operations::binary { @@ -27,6 +28,7 @@ inline Tensor binary_impl( BinaryOpType binary_op_type, const ttnn::Tensor &input_tensor, const float scalar, + const std::optional<const DataType> &dtype = std::nullopt, const std::optional<ttnn::MemoryConfig>
&memory_config = std::nullopt, const std::optional<Tensor> &optional_output_tensor = std::nullopt) { auto output_memory_config = optional_output_tensor.has_value() @@ -60,6 +62,8 @@ inline Tensor binary_impl( } else { TT_THROW("Unsupported operation"); } + if(dtype.has_value()) + output_tensor = ttnn::typecast(queue_id, output_tensor, dtype.value(), std::nullopt, optional_output_tensor); return output_tensor; } @@ -321,7 +325,7 @@ Tensor RelationalBinary<binary_op_type>::invoke( std::optional<unary::FusedActivations> activations, std::optional<unary::UnaryWithParam> input_tensor_a_activation) { return detail::binary_impl( - DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor); + DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor); } template <BinaryOpType binary_op_type> @@ -335,7 +339,7 @@ Tensor RelationalBinary<binary_op_type>::invoke( std::optional<unary::FusedActivations> activations, std::optional<unary::UnaryWithParam> input_tensor_a_activation) { return detail::binary_impl( - DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor); + DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor); } // scalar - tensor combination not available on Pytorch for this op template <BinaryOpType binary_op_type>