Add integer support for eltwise ops (#14953)
### Ticket
Link to GitHub Issue #13374

### What's changed

- Updated the comparison ops to accept an integer output dtype (uint16/uint32) and return results in that type; a usage sketch follows the ops list below.


Ops with integer support added:

- ELEMWISE_UNARY_EQ
- ELEMWISE_UNARY_GE
- ELEMWISE_UNARY_GT
- ELEMWISE_UNARY_LE
- ELEMWISE_UNARY_LT
- ELEMWISE_UNARY_NE
- ELEMWISE_BINARY_GE
- ELEMWISE_BINARY_GT
- ELEMWISE_BINARY_EQ
- ELEMWISE_BINARY_LE
- ELEMWISE_BINARY_LOGICALAND
- ELEMWISE_BINARY_LOGICALOR
- ELEMWISE_BINARY_LOGICALXOR
- ELEMWISE_BINARY_LT
- ELEMWISE_BINARY_NE
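
As a quick illustration, here is a minimal usage sketch of the new integer output path. The device setup, shapes, and the uint16 choice are illustrative, not taken from this PR:

```python
import torch
import ttnn

# Hypothetical standalone example; assumes a Tenstorrent device is available.
device = ttnn.open_device(device_id=0)

a = ttnn.from_torch(torch.randn(1, 1, 32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.randn(1, 1, 32, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# With this change, comparison ops can return integer results directly:
result = ttnn.gt(a, b, dtype=ttnn.uint16)  # 1 where a > b, else 0
print(ttnn.to_torch(result))

ttnn.close_device(device)
```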

### Checklist
- [ ] [All Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/11796923243)
mouliraj-mcw authored Nov 17, 2024
1 parent 42a30d7 commit 7a99cc8
Showing 2 changed files with 169 additions and 2 deletions.
163 changes: 163 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_comp_init.py
@@ -0,0 +1,163 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
    data_gen_with_range,
    data_gen_with_range_dtype,
)
from models.utility_functions import is_grayskull, skip_for_blackhole


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([64, 64])),
        (torch.Size([2, 32, 32])),
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "mem_configs",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
)
@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16))
@pytest.mark.parametrize(
    "ttnn_function",
    (ttnn.lt, ttnn.gt, ttnn.eq, ttnn.le, ttnn.ge, ttnn.ne, ttnn.logical_and, ttnn.logical_or, ttnn.logical_xor),
)
def test_binary_comp_ops(input_shapes, out_dtype, mem_configs, ttnn_function, device):
    if is_grayskull():
        pytest.skip("GS does not support fp32/uint32/uint16 data types")

    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)

    cq_id = 0
    mem_cfg = mem_configs

    tt_output_tensor_on_device = ttnn_function(
        input_tensor, other_tensor, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id
    )

    golden_fn = ttnn.get_golden_function(ttnn_function)
    golden_tensor = golden_fn(in_data, other_data)
    golden_tensor = golden_tensor.int()

    output_tensor = ttnn.to_torch(tt_output_tensor_on_device)

    are_equal = torch.equal(output_tensor, golden_tensor)
    assert are_equal


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([64, 64])),
        (torch.Size([2, 32, 32])),
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "mem_configs",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
)
@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16))
@pytest.mark.parametrize(
    "ttnn_function",
    (ttnn.lt, ttnn.gt, ttnn.eq, ttnn.le, ttnn.ge, ttnn.ne, ttnn.logical_and, ttnn.logical_or, ttnn.logical_xor),
)
def test_binary_comp_opt_out(input_shapes, out_dtype, mem_configs, ttnn_function, device):
    if is_grayskull():
        pytest.skip("GS does not support fp32/uint32/uint16 data types")

    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)

    cq_id = 0
    mem_cfg = mem_configs
    _, output_tensor = data_gen_with_range_dtype(input_shapes, -1, 1, device, False, False, out_dtype)
    ttnn_function(
        input_tensor, other_tensor, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id, output_tensor=output_tensor
    )

    golden_fn = ttnn.get_golden_function(ttnn_function)
    golden_tensor = golden_fn(in_data, other_data)
    golden_tensor = golden_tensor.int()

    output_tensor = ttnn.to_torch(output_tensor)

    are_equal = torch.equal(output_tensor, golden_tensor)
    assert are_equal


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([64, 64])),
        (torch.Size([2, 32, 32])),
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "mem_configs",
    (
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM),
        ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    ),
)
@pytest.mark.parametrize(
    "scalar",
    {2.3, 15.6, 55.4, 72.5, 120.6},
)
@pytest.mark.parametrize("out_dtype", (ttnn.uint32, ttnn.uint16))
@pytest.mark.parametrize(
    "ttnn_function",
    (
        ttnn.lt,
        ttnn.gt,
        ttnn.eq,
        ttnn.le,
        ttnn.ge,
        ttnn.ne,
    ),
)
def test_binary_comp_ops_scalar(input_shapes, scalar, out_dtype, mem_configs, ttnn_function, device):
    if is_grayskull():
        pytest.skip("GS does not support fp32/uint32/uint16 data types")

    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)

    cq_id = 0
    mem_cfg = mem_configs

    tt_output_tensor_on_device = ttnn_function(
        input_tensor, scalar, memory_config=mem_cfg, dtype=out_dtype, queue_id=cq_id
    )

    golden_fn = ttnn.get_golden_function(ttnn_function)
    golden_tensor = golden_fn(in_data, scalar)
    golden_tensor = golden_tensor.int()

    output_tensor = ttnn.to_torch(tt_output_tensor_on_device)

    are_equal = torch.equal(output_tensor, golden_tensor)
    assert are_equal
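
For context, the three tests above cover the tensor-tensor, preallocated-output, and tensor-scalar paths across DRAM/L1 interleaved memory configs and uint32/uint16 output dtypes. The preallocated output in test_binary_comp_opt_out comes from data_gen_with_range_dtype; a rough sketch of what such a helper plausibly does, as an assumption for illustration rather than the repo's actual implementation:

```python
import torch
import ttnn


def data_gen_with_range_dtype_sketch(shapes, low, high, device, required_grad, is_row_major, ttnn_dtype):
    # Hypothetical stand-in for the repo helper (names and argument order assumed):
    # build a random torch tensor in [low, high] and mirror it on device
    # with the requested ttnn dtype and layout.
    torch_tensor = torch.empty(shapes, dtype=torch.bfloat16).uniform_(low, high)
    torch_tensor.requires_grad = required_grad
    layout = ttnn.ROW_MAJOR_LAYOUT if is_row_major else ttnn.TILE_LAYOUT
    ttnn_tensor = ttnn.from_torch(torch_tensor, dtype=ttnn_dtype, layout=layout, device=device)
    return torch_tensor, ttnn_tensor
```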
8 changes: 6 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -10,6 +10,7 @@
 #include "ttnn/operations/data_movement/repeat/repeat.hpp"
 #include "ttnn/operations/eltwise/unary/unary.hpp"
 #include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
+#include "ttnn/operations/copy.hpp"
 
 namespace ttnn::operations::binary {
 
@@ -27,6 +28,7 @@ inline Tensor binary_impl(
     BinaryOpType binary_op_type,
     const ttnn::Tensor &input_tensor,
     const float scalar,
+    const std::optional<const DataType> &dtype = std::nullopt,
     const std::optional<ttnn::MemoryConfig> &memory_config = std::nullopt,
     const std::optional<Tensor> &optional_output_tensor = std::nullopt) {
     auto output_memory_config = optional_output_tensor.has_value()
@@ -60,6 +62,8 @@ inline Tensor binary_impl(
     } else {
         TT_THROW("Unsupported operation");
     }
+    if(dtype.has_value())
+        output_tensor = ttnn::typecast(queue_id, output_tensor, dtype.value(), std::nullopt, optional_output_tensor);
     return output_tensor;
 }
 
@@ -321,7 +325,7 @@ Tensor RelationalBinary<binary_op_type>::invoke(
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
     return detail::binary_impl(
-        DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor);
+        DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor);
 }
 
 template <BinaryOpType binary_op_type>
@@ -335,7 +339,7 @@ Tensor RelationalBinary<binary_op_type>::invoke(
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
     return detail::binary_impl(
-        DefaultQueueId, binary_op_type, input_tensor_a, scalar, memory_config, optional_output_tensor);
+        DefaultQueueId, binary_op_type, input_tensor_a, scalar, dtype, memory_config, optional_output_tensor);
 }
 // scalar - tensor combination not available on Pytorch for this op
 template <BinaryOpType binary_op_type>
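
In short, the C++ change threads an optional output dtype through the scalar overloads of RelationalBinary into binary_impl, which typecasts the result via ttnn::typecast when a dtype is supplied. A minimal sketch of the resulting behavior from Python, assuming a live device; the scalar 55.4 and the shape are illustrative:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# bfloat16 on both sides so the torch golden matches the device comparison exactly.
torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Scalar comparison with an integer output dtype (the path changed above).
out = ttnn.ne(input_tensor, 55.4, dtype=ttnn.uint16)

# Golden check in the same style as the new tests.
golden = ttnn.get_golden_function(ttnn.ne)(torch_input, 55.4).int()
assert torch.equal(ttnn.to_torch(out), golden)

ttnn.close_device(device)
```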
