From 430b0369e9f5bbb9fc051d3c584823c77ccea1bb Mon Sep 17 00:00:00 2001 From: Usman Aziz Date: Tue, 12 Nov 2024 16:55:33 -0500 Subject: [PATCH] Implement conversion for stablehlo.select and add Where Op (#852) * Added conversion for SelectOp to TTIR WhereOp along with end-to-end support. --------- Co-authored-by: Stefan Djordjevic <157365107+sdjordjevicTT@users.noreply.github.com> Co-authored-by: Milan Topalovic <163355844+mtopalovicTT@users.noreply.github.com> Co-authored-by: Nikola Obradovic <132568163+nobradovictt@users.noreply.github.com> Co-authored-by: Filip Bajraktari --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 23 +++++++++++++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 23 +++++++++++++ include/ttmlir/Target/TTNN/program.fbs | 1 + .../StableHLOToTTIRPatterns.cpp | 2 ++ lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 1 + lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 4 +-- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 5 +++ runtime/include/tt/runtime/detail/ttnn.h | 1 + runtime/lib/ttnn/operations/CMakeLists.txt | 2 ++ .../operations/eltwise/ternary/ternary.cpp | 32 +++++++++++++++++++ .../ttnn/operations/eltwise/ternary/ternary.h | 21 ++++++++++++ .../ttnn/operations/eltwise/ternary/utils.cpp | 24 ++++++++++++++ .../ttnn/operations/eltwise/ternary/utils.h | 21 ++++++++++++ runtime/lib/ttnn/program.cpp | 6 ++++ .../Conversion/StableHLOToTTIR/select_op.mlir | 13 ++++++++ test/ttmlir/Dialect/TTNN/simple_where.mlir | 14 ++++++++ .../TTNN/perf_unit/test_perf_where.mlir | 16 ++++++++++ test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 11 +++++++ 18 files changed, 218 insertions(+), 2 deletions(-) create mode 100644 runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp create mode 100644 runtime/lib/ttnn/operations/eltwise/ternary/ternary.h create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp create mode 100644 runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir create mode 100644 test/ttmlir/Dialect/TTNN/simple_where.mlir create mode 100644 test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index cd62b8289..96221e9f3 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -186,6 +186,29 @@ class TTIR_ElementwiseOp traits = []> : let results = (outs Variadic:$results); } +class TTIR_ElementwiseTernaryOp traits = []> : + TTIR_ElementwiseOp { + let summary = "Eltwise ternary op."; + let description = [{ + Eltwise ternary op. + }]; + + let builders = + [ + OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints), + [{ + build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints); + }]> + ]; +} + +def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> { + let summary = "Eltwise where op."; + let description = [{ + Eltwise where operation. + }]; +} + class TTIR_ElementwiseUnaryOp traits = []> : TTIR_ElementwiseOp { let summary = "Eltwise unary op."; diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 91cb51cca..d1272acb6 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -146,6 +146,29 @@ class TTNN_ElementwiseBinaryOp traits = []> : ]; } +class TTNN_ElementwiseTernaryOp traits = []> : + TTNN_ElementwiseOp { + let summary = "Eltwise ternary op."; + let description = [{ + Eltwise ternary op. + }]; + + let builders = + [ + OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out), + [{ + build($_builder, $_state, {out.getType()}, {first, second, third}, out); + }]> + ]; +} + +def TTNN_WhereOp : TTNN_ElementwiseTernaryOp<"where"> { + let summary = "Eltwise where."; + let description = [{ + Eltwise where operation. + }]; +} + def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> { let summary = "Eltwise absolute."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index e797ecf92..0bcc06e3c 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -97,6 +97,7 @@ enum EltwiseOpType: uint32 { Remainder = 32, IsFinite = 33, Floor = 34, + Where = 35, } union EltwiseOpParams { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 9d9fbee39..2ded09362 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1027,6 +1027,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, ctx); patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, ctx); } void addReduceOpsConversionPatterns(MLIRContext *ctx, diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 981045e90..23dce0553 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -922,6 +922,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, ReductionOpConversionPattern, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 0a04c53a8..9e3bda099 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -679,8 +679,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Other ops // patterns.add, - DefaultOpConversionPattern>(typeConverter, - ctx); + DefaultOpConversionPattern, + DefaultOpConversionPattern>(typeConverter, ctx); // CCL ops // diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index a3262a680..7da8bb353 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -381,6 +381,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Expm1; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Remainder; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::Where; } else { llvm_unreachable("unhandled EltwiseOp"); } @@ -719,6 +721,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto sinOp = dyn_cast(op); sinOp) { return createOperation(cache, createEltwiseOp(cache, sinOp), debugString); } + if (auto whereOp = dyn_cast(op); whereOp) { + return createOperation(cache, createEltwiseOp(cache, whereOp), debugString); + } llvm_unreachable("unhandled op in emitTTNNOperation"); } diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 4580b290b..9654142e7 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -56,6 +56,7 @@ #include "ttnn/operations/data_movement/permute/permute.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/binary/binary_composite.hpp" +#include "ttnn/operations/eltwise/ternary/where.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/embedding/embedding.hpp" #include "ttnn/operations/matmul/matmul.hpp" diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt index db67164ef..3a1e20bfa 100644 --- a/runtime/lib/ttnn/operations/CMakeLists.txt +++ b/runtime/lib/ttnn/operations/CMakeLists.txt @@ -2,6 +2,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp @@ -15,6 +16,7 @@ set(TTNN_OPS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/ternary/ternary.cpp ${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp diff --git a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp new file mode 100644 index 000000000..a4b29c4b8 --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp @@ -0,0 +1,32 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ternary.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h" +#include "tt/runtime/ttnn/operations/utils.h" + +namespace tt::runtime::ttnn::operations::ternary { + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { + if (op->type() != ::tt::target::ttnn::EltwiseOpType::Where) { + throw std::invalid_argument("Unsupported Eltwise Ternary operation"); + } + + ProgramTensorPool &tensorPool = context.getTensorPool(); + + ::ttnn::Tensor *first = nullptr; + ::ttnn::Tensor *second = nullptr; + ::ttnn::Tensor *third = nullptr; + getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third); + + ::tt::tt_metal::MemoryConfig outputMemoryConfig = + utils::createMemoryConfig(op->out()); + + ::ttnn::Tensor out = + ::ttnn::where(*first, *second, *third, outputMemoryConfig); + tensorPool.insert_or_assign(op->out()->global_id(), out); +} +} // namespace tt::runtime::ttnn::operations::ternary diff --git a/runtime/lib/ttnn/operations/eltwise/ternary/ternary.h b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.h new file mode 100644 index 000000000..1e756ef24 --- /dev/null +++ b/runtime/lib/ttnn/operations/eltwise/ternary/ternary.h @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H +#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H + +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::ternary { + +inline bool isTernaryOp(const ::tt::target::ttnn::EltwiseOp *op) { + return op->ins()->size() == 3; +} + +void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); + +} // namespace tt::runtime::ttnn::operations::ternary + +#endif diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp new file mode 100644 index 000000000..d394e928f --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include "utils.h" +#include "tt/runtime/detail/logger.h" +#include "tt/runtime/detail/workarounds.h" + +namespace tt::runtime::ttnn::operations::ternary { + +void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **first, + ::ttnn::Tensor **second, + ::ttnn::Tensor **third) { + LOG_ASSERT(op->ins()->size() == 3, "Expected 3 inputs"); + *first = &(tensorPool.at(op->ins()->Get(0)->global_id())); + *second = &(tensorPool.at(op->ins()->Get(1)->global_id())); + *third = &(tensorPool.at(op->ins()->Get(2)->global_id())); + DEBUG_ASSERT((*first)->is_allocated()); + DEBUG_ASSERT((*second)->is_allocated()); + DEBUG_ASSERT((*third)->is_allocated()); +} + +} // namespace tt::runtime::ttnn::operations::ternary diff --git a/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h new file mode 100644 index 000000000..774cbdc3e --- /dev/null +++ b/runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H +#define TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H + +#include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/ttnn/types.h" +#include "ttmlir/Target/TTNN/program_generated.h" + +namespace tt::runtime::ttnn::operations::ternary { +void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, + ProgramTensorPool &tensorPool, + ::ttnn::Tensor **first, + ::ttnn::Tensor **second, + ::ttnn::Tensor **third); + +} // namespace tt::runtime::ttnn::operations::ternary + +#endif diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index af1b28d99..00db959de 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -13,6 +13,7 @@ #include "operations/deletion/dealloc.h" #include "operations/eltwise/binary/binary.h" #include "operations/eltwise/binary/binary_composite.h" +#include "operations/eltwise/ternary/ternary.h" #include "operations/eltwise/unary/unary.h" #include "operations/eltwise/unary/unary_composite.h" #include "operations/embedding/embedding.h" @@ -73,6 +74,8 @@ void ProgramExecutor::runEltwiseOperation( return operations::binary::run(op, context); }; + auto runTernaryOp = [&]() { return operations::ternary::run(op, context); }; + if (operations::unary::isUnaryOp(op)) { return runUnaryOp(); } @@ -80,6 +83,9 @@ void ProgramExecutor::runEltwiseOperation( if (operations::binary::isBinaryOp(op)) { return runBinaryOp(); } + if (operations::ternary::isTernaryOp(op)) { + return runTernaryOp(); + } throw std::invalid_argument("Unsupported Eltwise operation"); } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir new file mode 100644 index 000000000..458879081 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir @@ -0,0 +1,13 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint +module @jit_eltwise_select attributes {} { + func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { + %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1> + %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() + // CHECK: %[[VAL1:[0-9]+]] = "ttir.eq" + // CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + return %1 : tensor<13x37xf32> + } +} diff --git a/test/ttmlir/Dialect/TTNN/simple_where.mlir b/test/ttmlir/Dialect/TTNN/simple_where.mlir new file mode 100644 index 000000000..a535a4fd9 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_where.mlir @@ -0,0 +1,14 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module @jit_eltwise_where { + func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { + %0 = tensor.empty() : tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %2 = tensor.empty() : tensor<13x37xf32> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) + return %3 : tensor<13x37xf32> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir new file mode 100644 index 000000000..3bed0528c --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { + %0 = tensor.empty() : tensor<13x37xbf16> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xf32> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) + return %3 : tensor<13x37xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 7a07cc15a..229830f48 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -258,3 +258,14 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { return %0 : tensor<1xi32> // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}> } + +func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { + %0 = tensor.empty() : tensor<13x37xbf16> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %2 = tensor.empty() : tensor<13x37xf32> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} + // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) + // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) + return %3 : tensor<13x37xf32> +}