diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 96221e9f3..3ffe2a977 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -260,6 +260,13 @@ def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> { }]; } +def TTIR_GeluOp: TTIR_ElementwiseUnaryOp<"gelu"> { + let summary = "Eltwise GELU op."; + let description = [{ + Eltwise GELU operation. + }]; +} + def TTIR_IsFiniteOp: TTIR_ElementwiseUnaryOp<"isfinite"> { let summary = "Eltwise isfinite op."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 11baafbe7..7601214b2 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -223,6 +223,13 @@ def TTNN_FloorOp: TTNN_ElementwiseUnaryOp<"floor"> { }]; } +def TTNN_GeluOp: TTNN_ElementwiseUnaryOp<"gelu"> { + let summary = "Eltwise GELU op."; + let description = [{ + Eltwise GELU operation.
+ }]; +} + def TTNN_IsFiniteOp: TTNN_ElementwiseUnaryOp<"isfinite"> { let summary = "Eltwise isfinite op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index e501cf3a8..fa1cd424f 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -98,6 +98,7 @@ enum EltwiseOpType: uint32 { IsFinite = 33, Floor = 34, Where = 35, + Gelu = 36, } union EltwiseOpParams { diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 23dce0553..f2501186f 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -908,6 +908,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, + ElementwiseOpConversionPattern<ttir::GeluOp, ttnn::GeluOp>, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, ElementwiseOpConversionPattern, diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 5a68e4875..4c6fadd83 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -651,6 +651,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, + DefaultOpConversionPattern<ttnn::GeluOp>, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 118c19bf9..fb81ab646 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -383,6 +383,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Remainder; } else if constexpr (std::is_same_v<EltwiseOp, WhereOp>) { type = ::tt::target::ttnn::EltwiseOpType::Where; + } else if constexpr (std::is_same_v<EltwiseOp, GeluOp>) 
{ + type = ::tt::target::ttnn::EltwiseOpType::Gelu; } else { llvm_unreachable("unhandled EltwiseOp"); } @@ -725,6 +727,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, if (auto whereOp = dyn_cast<WhereOp>(op); whereOp) { return createOperation(cache, createEltwiseOp(cache, whereOp), debugString); } + if (auto geluOp = dyn_cast<GeluOp>(op); geluOp) { + return createOperation(cache, createEltwiseOp(cache, geluOp), debugString); + } llvm_unreachable("unhandled op in emitTTNNOperation"); } diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp index 404349404..8ec84a69f 100644 --- a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp +++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp @@ -66,6 +66,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { runEltwiseUnaryOP(op, tensorPool, ::ttnn::floor); break; } + case ::tt::target::ttnn::EltwiseOpType::Gelu: { + runEltwiseUnaryWithFastAndApproximateModeOP(op, tensorPool, ::ttnn::gelu); + break; + } case ::tt::target::ttnn::EltwiseOpType::IsFinite: { runEltwiseUnaryOP(op, tensorPool, ::ttnn::isfinite); break; diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir new file mode 100644 index 000000000..0fe3e9c3b --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s +#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> +module attributes {} { + func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x128xf32, + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: "ttnn.gelu" + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, 
#any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> + } +} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir new file mode 100644 index 000000000..628bb5c37 --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir @@ -0,0 +1,17 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir +// RUN: FileCheck %s --input-file=%t.mlir +// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn +#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> +#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> + +func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x128xf32, + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: "ttnn.gelu" + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +} diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index 229830f48..1ff209b26 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -269,3 +269,15 @@ func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> ten // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) return %3 : tensor<13x37xf32> } + +func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: "ttnn.empty" + // CHECK-SAME: tensor<64x128xf32, + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: "ttnn.gelu" + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + // CHECK-SAME: tensor<64x128xf32, + %1 = "ttir.gelu"(%arg0, %0) 
<{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + return %1 : tensor<64x128xf32> +}