diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 3ffe2a977..1bcdf3f7a 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -451,6 +451,13 @@ def TTIR_LogicalOrOp : TTIR_ElementwiseBinaryOp<"logical_or"> { }]; } +def TTIR_LogicalXorOp : TTIR_ElementwiseBinaryOp<"logical_xor"> { + let summary = "Eltwise logical xor."; + let description = [{ + Eltwise logical xor operation. + }]; +} + def TTIR_MaximumOp : TTIR_ElementwiseBinaryOp<"maximum"> { let summary = "Eltwise maximum OP."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 7601214b2..ccb037d58 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -395,6 +395,13 @@ def TTNN_LogicalOrOp : TTNN_ElementwiseBinaryOp<"logical_or"> { }]; } +def TTNN_LogicalXorOp : TTNN_ElementwiseBinaryOp<"logical_xor"> { + let summary = "Eltwise logical xor."; + let description = [{ + Eltwise logical xor operation. + }]; +} + def TTNN_MaximumOp : TTNN_ElementwiseBinaryOp<"maximum"> { let summary = "Eltwise maximum OP."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index fa1cd424f..2bec46c02 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -99,6 +99,7 @@ enum EltwiseOpType: uint32 { Floor = 34, Where = 35, Gelu = 36, + LogicalXor = 37, } union EltwiseOpParams { diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 2ded09362..b51a8e415 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1114,6 +1114,9 @@ void addLogicalOpConversionPattern(MLIRContext *ctx, ctx); patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); } void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns, diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index f2501186f..d3a4c373b 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -897,6 +897,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 4c6fadd83..b7d62e2a8 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -670,6 +670,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, patterns.add, EltwiseBinaryOpConversionPattern, EltwiseBinaryOpConversionPattern, + EltwiseBinaryOpConversionPattern, EltwiseBinaryOpConversionPattern, EltwiseBinaryOpConversionPattern, DefaultOpConversionPattern, diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index fb81ab646..b7c3c56b6 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -329,6 +329,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::LogicalNot; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::LogicalOr; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::LogicalXor; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Multiply; } else if constexpr (std::is_same_v) { @@ -590,6 +592,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto orOp = dyn_cast(op); orOp) { return createOperation(cache, createEltwiseOp(cache, orOp), debugString); } + if (auto xorOp = dyn_cast(op); xorOp) { + return createOperation(cache, createEltwiseOp(cache, xorOp), debugString); + } if (auto multiplyOp = dyn_cast(op); multiplyOp) { return createOperation(cache, createEltwiseOp(cache, multiplyOp), debugString); diff --git a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp index 6266dd721..f60c0b0b3 100644 --- a/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/binary/binary.cpp @@ -50,6 +50,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_or); break; } + case ::tt::target::ttnn::EltwiseOpType::LogicalXor: { + runEltwiseBinaryOP(op, tensorPool, ::ttnn::logical_xor); + break; + } case ::tt::target::ttnn::EltwiseOpType::Multiply: { runEltwiseBinaryOP(op, tensorPool, ::ttnn::multiply); break; diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir index 8aa5417c3..f82d71877 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/logical_op.mlir @@ -1,39 +1,30 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -module @jit_eltwise_compare attributes {} { - func.func public @logical_and(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> { - %0 = stablehlo.and %arg0, %arg1 : tensor<13x31xi1> - // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16> +module @jit_eltwise_logical attributes {} { + func.func public @logical_and(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> { + // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]] // CHECK: = "ttir.logical_and"(%arg0, %arg1, %[[E]]) - // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16> - return %0 : tensor<13x31xi1> - // CHECK: return %1 : tensor<13x31xbf16> + // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]] + %0 = stablehlo.and %arg0, %arg1 : tensor<32x32xi1> + // CHECK: return %1 : [[TENSOR]] + return %0 : tensor<32x32xi1> } - func.func public @logical_or(%arg0: tensor<13x31xi1>, %arg1: tensor<13x31xi1>) -> tensor<13x31xi1> { - %0 = stablehlo.or %arg0, %arg1 : tensor<13x31xi1> - // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16> + func.func public @logical_or(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> { + // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]] // CHECK: = "ttir.logical_or"(%arg0, %arg1, %[[E]]) - // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16> - return %0 : tensor<13x31xi1> - // CHECK: return %1 : tensor<13x31xbf16> + // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]] + %0 = stablehlo.or %arg0, %arg1 : tensor<32x32xi1> + // CHECK: return %1 : [[TENSOR]] + return %0 : tensor<32x32xi1> } -func.func public @logical_not(%arg0: tensor<13x31xi1>) -> tensor<13x31xi1> { - %0 = stablehlo.not %arg0 : tensor<13x31xi1> - // CHECK: %[[E:.*]] = tensor.empty() : tensor<13x31xbf16> - // CHECK: = "ttir.logical_not"(%arg0, %[[E]]) - // CHECK-SAME: (tensor<13x31xbf16>, tensor<13x31xbf16>) -> tensor<13x31xbf16> - return %0 : tensor<13x31xi1> - // CHECK: return %1 : tensor<13x31xbf16> - } - -func.func public @logical_not_scalar(%arg0: tensor) -> tensor { - %0 = stablehlo.not %arg0 : tensor - // CHECK: %[[E:.*]] = tensor.empty() : tensor<1xbf16> - // CHECK: = "ttir.logical_not"(%arg0, %[[E]]) - // CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16> - return %0 : tensor - // CHECK: return %1 : tensor<1xbf16> + func.func public @logical_xor(%arg0: tensor<32x32xi1>, %arg1: tensor<32x32xi1>) -> tensor<32x32xi1> { + // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]] + // CHECK: = "ttir.logical_xor"(%arg0, %arg1, %[[E]]) + // CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]] + %0 = stablehlo.xor %arg0, %arg1 : tensor<32x32xi1> + // CHECK: return %1 : [[TENSOR]] + return %0 : tensor<32x32xi1> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir new file mode 100644 index 000000000..ef63726a6 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logical_op.mlir @@ -0,0 +1,21 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_eltwise_logical attributes {} { + func.func public @logical_not(%arg0: tensor<32x32xi1>) -> tensor<32x32xi1> { + // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<32x32xbf16>]] + // CHECK: = "ttir.logical_not"(%arg0, %[[E]]) + // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]] + %0 = stablehlo.not %arg0 : tensor<32x32xi1> + // CHECK: return %1 : [[TENSOR]] + return %0 : tensor<32x32xi1> + } + + func.func public @logical_not_scalar(%arg0: tensor) -> tensor { + // CHECK: %[[E:.*]] = tensor.empty() : [[TENSOR:tensor<1xbf16>]] + // CHECK: = "ttir.logical_not"(%arg0, %[[E]]) + // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]] + %0 = stablehlo.not %arg0 : tensor + // CHECK: return %1 : [[TENSOR]] + return %0 : tensor + } +} diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir new file mode 100644 index 000000000..f59f49040 --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir @@ -0,0 +1,16 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s + +#any_device = #tt.operand_constraint +module attributes {} { + func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]] + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %{{[0-9]+}} = "ttnn.logical_xor" + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir new file mode 100644 index 000000000..c47a34cee --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir @@ -0,0 +1,18 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn + +#any_device = #tt.operand_constraint +#any_device_tile = #tt.operand_constraint + +func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]] + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %{{[0-9]+}} = "ttnn.logical_xor" + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_logical.mlir b/test/ttmlir/Silicon/TTNN/simple_logical.mlir index fbb2d1819..e5d68f5ec 100644 --- a/test/ttmlir/Silicon/TTNN/simple_logical.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_logical.mlir @@ -37,4 +37,16 @@ module attributes {} { // CHECK-SAME: tensor<64x128xf32, return %1 : tensor<64x128xf32> } + + func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { + // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]] + %0 = tensor.empty() : tensor<64x128xbf16> + // CHECK: %{{[0-9]+}} = "ttnn.logical_xor" + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: [[TENSOR]] + // CHECK-SAME: -> [[TENSOR]] + %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + return %1 : tensor<64x128xbf16> + } }