diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index c87770bef8..9d16aa9c67 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -388,6 +388,13 @@ def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> {
   }];
 }
 
+def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> {
+  let summary = "Eltwise logical xor.";
+  let description = [{
+    Eltwise logical xor operation.
+  }];
+}
+
 def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> {
   let summary = "Eltwise maximum OP.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 01ebae8030..8bff2214c5 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -349,6 +349,13 @@ def TTNN_LogicalOrOp : TTNN_ElementwiseBinaryOp<"logical_or"> {
   }];
 }
 
+def TTNN_LogicalXorOp : TTNN_ElementwiseBinaryOp<"logical_xor"> {
+  let summary = "Eltwise logical xor.";
+  let description = [{
+    Eltwise logical xor operation.
+  }];
+}
+
 def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> {
   let summary = "Eltwise maximum OP.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index cea35447c6..565cffec6b 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -89,7 +89,8 @@ enum EltwiseOpType: uint32 {
   Log = 28,
   Log1p = 29,
   Expm1 = 30,
-  Sign = 31
+  Sign = 31,
+  LogicalXor = 32,
 }
 
 union EltwiseOpParams {
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 15b1f086b4..6c187f6eae 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -962,6 +962,9 @@ void addLogicalOpConversionPattern(MLIRContext *ctx,
                                                       ctx);
   patterns.add<StableHLOToTTIRLogicalOpConversionPattern<mlir::stablehlo::OrOp,
       mlir::tt::ttir::LogicalOrOp>>(typeConverter, ctx);
+  patterns.add<StableHLOToTTIRLogicalOpConversionPattern<mlir::stablehlo::XorOp,
+      mlir::tt::ttir::LogicalXorOp>>(typeConverter,
+                                     ctx);
 }
 
 void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 42d834634e..4b3a705421 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -863,6 +863,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ElementwiseOpConversionPattern<ttir::LogicalAndOp, ttnn::LogicalAndOp>,
            ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
            ElementwiseOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
+           ElementwiseOpConversionPattern<ttir::LogicalXorOp, ttnn::LogicalXorOp>,
            ElementwiseOpConversionPattern<ttir::MaximumOp, ttnn::MaximumOp>,
            ElementwiseOpConversionPattern<ttir::MultiplyOp, ttnn::MultiplyOp>,
            ElementwiseOpConversionPattern<ttir::SubtractOp, ttnn::SubtractOp>,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 0582dce371..ea018096b1 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -637,6 +637,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   patterns.add<DefaultOpConversionPattern<ttnn::AddOp>,
                DefaultOpConversionPattern<ttnn::LogicalAndOp>,
                DefaultOpConversionPattern<ttnn::LogicalOrOp>,
+               DefaultOpConversionPattern<ttnn::LogicalXorOp>,
                DefaultOpConversionPattern<ttnn::MaximumOp>,
                MultiplyOpConversionPattern,
                DefaultOpConversionPattern<ttnn::SubtractOp>,
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 71d793a002..bd237fa6ff 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -309,6 +309,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::LogicalNot;
   } else if constexpr (std::is_same_v<EltwiseOp, LogicalOrOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::LogicalOr;
+  } else if constexpr (std::is_same_v<EltwiseOp, LogicalXorOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::LogicalXor;
   } else if constexpr (std::is_same_v<EltwiseOp, MultiplyOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Multiply;
   } else if constexpr (std::is_same_v<EltwiseOp, NegOp>) {
@@ -556,6 +558,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto orOp = dyn_cast<LogicalOrOp>(op); orOp) {
     return createOperation(cache, createEltwiseOp(cache, orOp), debugString);
   }
+  if (auto xorOp = dyn_cast<LogicalXorOp>(op); xorOp) {
+    return createOperation(cache, createEltwiseOp(cache, xorOp), debugString);
+  }
   if (auto multiplyOp = dyn_cast<MultiplyOp>(op); multiplyOp) {
     return createOperation(cache, createEltwiseOp(cache, multiplyOp),
                            debugString);
diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
index 6266dd721f..f60c0b0b30 100644
--- a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp
@@ -50,6 +50,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_or);
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::LogicalXor: {
+    runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_xor);
+    break;
+  }
   case ::tt::target::ttnn::EltwiseOpType::Multiply: {
     runEltwiseBinaryOP(op, tensorPool, ::ttnn::multiply);
     break;
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir
index 8aa5417c30..86ad0fe269 100644
--- a/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir
@@ -1,39 +1,30 @@
 // REQUIRES: stablehlo
 // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
-module @jit_eltwise_compare attributes {} {
+module @jit_eltwise_logical attributes {} {
   func.func public @logical_and(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> {
-    %0 = stablehlo.and %arg0, %arg1 : tensor<13x31xi1>
-    // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
+    // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<13x31xbf16>]]
     // CHECK: = "ttir.logical_and"(%arg0, %arg1, %[[E]])
-    // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.and %arg0, %arg1 : tensor<13x31xi1>
+    // CHECK: return %1 : [[TENSOR]]
     return %0 : tensor<13x31xi1>
-    // CHECK: return %1 : tensor<13x31xbf16>
   }
 
   func.func public @logical_or(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> {
-    %0 = stablehlo.or %arg0, %arg1 : tensor<13x31xi1>
-    // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
+    // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<13x31xbf16>]]
     // CHECK: = "ttir.logical_or"(%arg0, %arg1, %[[E]])
-    // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.or %arg0, %arg1 : tensor<13x31xi1>
+    // CHECK: return %1 : [[TENSOR]]
     return %0 : tensor<13x31xi1>
-    // CHECK: return %1 : tensor<13x31xbf16>
  }
 
-  func.func public @logical_not(%arg0: tensor<13x31xi1>) -> tensor<13x31xi1> {
-    %0 = stablehlo.not %arg0 : tensor<13x31xi1>
-    // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
-    // CHECK: = "ttir.logical_not"(%arg0, %[[E]])
-    // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
+  func.func public @logical_xor(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> {
+    // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<13x31xbf16>]]
+    // CHECK: = "ttir.logical_xor"(%arg0, %arg1, %[[E]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.xor %arg0, %arg1 : tensor<13x31xi1>
+    // CHECK: return %1 : [[TENSOR]]
     return %0 : tensor<13x31xi1>
-    // CHECK: return %1 : tensor<13x31xbf16>
-  }
-
-  func.func public @logical_not_scalar(%arg0: tensor<i1>) -> tensor<i1> {
-    %0 = stablehlo.not %arg0 : tensor<i1>
-    // CHECK: %[[E:.*]] = tensor.empty() : tensor<1xbf16>
-    // CHECK: = "ttir.logical_not"(%arg0, %[[E]])
-    // CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
-    return %0 : tensor<i1>
-    // CHECK: return %1 : tensor<1xbf16>
   }
 }
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir
new file mode 100644
index 0000000000..d094e59132
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir
@@ -0,0 +1,22 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_eltwise_logical attributes {} {
+  func.func public @logical_not(%arg0: tensor<13x31xi1>) -> tensor<13x31xi1> {
+    // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16>
+    // CHECK: = "ttir.logical_not"(%arg0, %[[E]])
+    // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16>
+    %0 = stablehlo.not %arg0 : tensor<13x31xi1>
+    // CHECK: return %1 : tensor<13x31xbf16>
+    return %0 : tensor<13x31xi1>
+  }
+
+  func.func public @logical_not_scalar(%arg0: tensor<i1>) -> tensor<i1> {
+    // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<1xbf16>]]
+    // CHECK: = "ttir.logical_not"(%arg0, %[[E]])
+    // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
+    %0 = stablehlo.not %arg0 : tensor<i1>
+    // CHECK: return %1 : [[TENSOR]]
+    return %0 : tensor<i1>
+
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir
new file mode 100644
index 0000000000..f59f49040a
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir
@@ -0,0 +1,16 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
+    %0 = tensor.empty() : tensor<64x128xbf16>
+    // CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: -> [[TENSOR]]
+    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %1 : tensor<64x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir
new file mode 100644
index 0000000000..9cec1b0b5f
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir
@@ -0,0 +1,19 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// REQUIRES: https://github.com/tenstorrent/tt-mlir/issues/1149
+
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
+
+func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+  // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
+  %0 = tensor.empty() : tensor<64x128xbf16>
+  // CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
+  // CHECK-SAME: [[TENSOR]]
+  // CHECK-SAME: [[TENSOR]]
+  // CHECK-SAME: [[TENSOR]]
+  // CHECK-SAME: -> [[TENSOR]]
+  %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  return %1 : tensor<64x128xbf16>
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_logical_xor.mlir b/test/ttmlir/Silicon/TTNN/simple_logical_xor.mlir
new file mode 100644
index 0000000000..2bab8dce23
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/simple_logical_xor.mlir
@@ -0,0 +1,21 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// REQUIRES: https://github.com/tenstorrent/tt-mlir/issues/1149
+
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
+
+module attributes {} {
+  func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
+    // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
+    %0 = tensor.empty() : tensor<64x128xbf16>
+    // CHECK: %{{[0-9]+}} = "ttnn.logical_xor"
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: [[TENSOR]]
+    // CHECK-SAME: -> [[TENSOR]]
+    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    return %1 : tensor<64x128xbf16>
+  }
+}