#14730: Support unequal ranked inputs for eltwise binary (#14803)
### Ticket
Links to GitHub issues #14730 and #14731.
### Problem description
Eltwise binary operations need to support inputs of unequal rank (e.g. adding a `[7, 7]` tensor to a `[1, 71, 7, 7]` tensor).

### What's changed
Added support using `ttnn.reshape`: when the inputs have different ranks, the lower-rank input's shape is left-padded with 1s up to the higher rank before the existing broadcast path runs. A minimal sketch of the rule is below.
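
For reviewers, here is a minimal Python sketch of the rank-alignment rule (it mirrors the new C++ `preprocess_inputs` logic in this diff; the helper name `pad_shape_to_rank` is illustrative only, not a ttnn API):

```python
def pad_shape_to_rank(shape, target_rank):
    """Left-pad a shape with 1s until it reaches target_rank."""
    return [1] * (target_rank - len(shape)) + list(shape)


# Shapes from the new test cases:
assert pad_shape_to_rank([7, 7], 4) == [1, 1, 7, 7]
assert pad_shape_to_rank([256], 3) == [1, 1, 256]
```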

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/11736842353

https://github.com/tenstorrent/tt-metal/actions/runs/11794021567/attempts/2
- [ ] Nightly FD
https://github.com/tenstorrent/tt-metal/actions/runs/11736844358

https://github.com/tenstorrent/tt-metal/actions/runs/11794025766/job/32854516936
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
KalaivaniMCW authored Nov 12, 2024
1 parent 24a6dbf commit 274f58a
Showing 3 changed files with 88 additions and 3 deletions.
27 changes: 27 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_add.py
@@ -10,6 +10,33 @@
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
    "shapes",
    [
        [[1, 71, 7, 7], [7, 7]],
        [[920, 1, 256], [256]],
    ],
)
def test_unequal_ranks(device, shapes):
    torch.manual_seed(0)

    torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16)
    torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16)
    torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    input_tensor_b = ttnn.from_torch(
        torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    output_tensor = ttnn.to_torch(output_tensor)

    assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988


@pytest.mark.parametrize(
    "shapes",
    [
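For context on the assertion in the test above: `ttnn.pearson_correlation_coefficient` compares the device output against the torch reference. A rough standalone equivalent, assuming it computes the standard Pearson correlation coefficient as its name suggests, is:

```python
import torch


def pearson_r(expected: torch.Tensor, actual: torch.Tensor) -> float:
    """Pearson correlation between two tensors, flattened to 1-D."""
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()
```

A value of 1.0 means the outputs are perfectly linearly correlated; the tests accept anything at or above 0.99988 to allow for bfloat16 rounding.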
32 changes: 32 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_mul.py
@@ -12,6 +12,38 @@
from torch.nn import functional as F


@pytest.mark.parametrize(
    "shapes",
    [
        [[4, 12, 64, 64], [12, 1, 1]],
        [[4, 16, 64, 64], [16, 1, 1]],
        [[64, 3, 64, 64], [3, 1, 1]],
        [[64, 4, 64, 64], [4, 1, 1]],
        [[16, 6, 64, 64], [6, 1, 1]],
        [[16, 8, 64, 64], [8, 1, 1]],
        [[1, 1], [1, 1, 32]],
    ],
)
def test_unequal_ranks(device, shapes):
    torch.manual_seed(0)

    torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16)
    torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16)
    torch_output_tensor = torch_input_tensor_a * torch_input_tensor_b

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )
    input_tensor_b = ttnn.from_torch(
        torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
    )

    output_tensor = ttnn.mul(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG)
    output_tensor = ttnn.to_torch(output_tensor)

    assert ttnn.pearson_correlation_coefficient(torch_output_tensor, output_tensor) >= 0.99988


# fmt: off
@pytest.mark.parametrize("scalar", [3.0])
# fmt: on
32 changes: 29 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -9,6 +9,7 @@
#include "ttnn/device_operation.hpp"
#include "ttnn/operations/data_movement/repeat/repeat.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"

namespace ttnn::operations::binary {

@@ -99,10 +100,34 @@ inline Tensor binary_impl(
}

template <BinaryOpType binary_op_type>
-auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg) {
+auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg, const std::optional<Tensor> &optional_output_tensor) {
    Tensor input_tensor_a = input_tensor_a_arg;
    Tensor input_tensor_b = input_tensor_b_arg;

    auto rank_a = input_tensor_a.get_shape().rank();
    auto rank_b = input_tensor_b.get_shape().rank();

    if (rank_a != rank_b) {
        auto max_rank = std::max(rank_a, rank_b);
        auto min_rank = std::min(rank_a, rank_b);

        if (optional_output_tensor.has_value()) {
            auto opt_rank = optional_output_tensor.value().get_shape().rank();
            TT_FATAL(max_rank == opt_rank,
                "Output Tensor rank {} doesn't match input tensor rank {}.", opt_rank, max_rank);
        }

        std::vector<int32_t> shape_vector(max_rank, 1);
        auto &reshaped_tensor = (rank_a > rank_b) ? input_tensor_b : input_tensor_a;
        auto s_b = reshaped_tensor.get_shape();
        for (int i = 0; i < min_rank; ++i) {
            shape_vector[(max_rank - min_rank) + i] = s_b[i];
        }
        reshaped_tensor = ttnn::reshape(reshaped_tensor, shape_vector);
    }

    // TODO: #7731 (Remove calls to repeat)
    auto repeat_smaller = [](const auto &first, auto &second) {
        const auto first_shape = first.get_shape();
@@ -149,8 +174,9 @@ Tensor BinaryOperation<binary_op_type>::invoke(
    std::optional<Tensor> optional_output_tensor,
    std::optional<unary::FusedActivations> activations,
    std::optional<unary::UnaryWithParam> input_tensor_a_activation) {

    auto [input_tensor_a, input_tensor_b] =
-        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg);
+        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg, optional_output_tensor);

    return ttnn::prim::binary(
        queue_id,
@@ -245,7 +271,7 @@ Tensor RelationalBinary<binary_op_type>::invoke(
}

    auto [input_tensor_a, input_tensor_b] =
-        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg);
+        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg, optional_output_tensor);

    auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
    DataType dtype = output_dtype.value_or(input_tensor_a.get_dtype());
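To make the effect of the change concrete, here is a torch-only illustration for one of the new test shapes (it assumes NumPy-style broadcasting semantics, which the pre-existing repeat path then materializes on device):

```python
import torch

# Shapes from the new test_add case: a is rank 4, b is rank 2.
shape_a = (1, 71, 7, 7)
shape_b = (7, 7)

# Step 1 -- the new ttnn::reshape call: left-pad the lower-rank shape with 1s.
aligned_b = (1,) * (len(shape_a) - len(shape_b)) + shape_b  # (1, 1, 7, 7)

# Step 2 -- the pre-existing repeat path: ordinary broadcasting now applies.
print(torch.broadcast_shapes(shape_a, aligned_b))  # torch.Size([1, 71, 7, 7])
```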
