diff --git a/test/ttmlir/Dialect/TTNN/simple_where.mlir b/test/ttmlir/Dialect/TTNN/simple_where.mlir new file mode 100644 index 0000000000..9df12c77f5 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_where.mlir @@ -0,0 +1,13 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device_tile = #tt.operand_constraint +module @jit_eltwise_where { + func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { + %0 = tensor.empty() : tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %2 = tensor.empty() : tensor<13x37xf32> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.where"[[C:.*]] + return %3 : tensor<13x37xf32> + } +}