diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 6bb03a6974..63040435e6 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -210,6 +210,13 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
   }];
 }
 
+def TTIR_GeluOp: TTIR_ElementwiseUnaryOp<"gelu"> {
+  let summary = "Eltwise GELU op.";
+  let description = [{
+    Eltwise GELU operation.
+  }];
+}
+
 def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
   let summary = "Eltwise logical not op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index f878e7f21d..225fcfb759 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -181,6 +181,13 @@ def TTNN_ExpOp : TTNN_ElementwiseUnaryOp<"exp"> {
   }];
 }
 
+def TTNN_GeluOp: TTNN_ElementwiseUnaryOp<"gelu"> {
+  let summary = "Eltwise GELU.";
+  let description = [{
+    Eltwise GELU operation.
+  }];
+}
+
 def TTNN_LogicalNotOp: TTNN_ElementwiseUnaryOp<"logical_not"> {
   let summary = "Eltwise logical not op.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 56c80410d0..8d1d90dce3 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -85,7 +85,8 @@ enum EltwiseOpType: uint32 {
   Minimum = 24,
   Ceil = 25,
   Sin = 26,
-  Cos = 27
+  Cos = 27,
+  Gelu = 28
 }
 
 table EltwiseOp {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 6391d48e9b..4396b4707c 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -826,6 +826,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ElementwiseOpConversionPattern<ttir::CeilOp, ttnn::CeilOp>,
            ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
            ElementwiseOpConversionPattern<ttir::ExpOp, ttnn::ExpOp>,
+           ElementwiseOpConversionPattern<ttir::GeluOp, ttnn::GeluOp>,
            ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
            ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
            ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 4655bc0c33..478309d6ad 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -619,6 +619,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
            DefaultOpConversionPattern<ttnn::CeilOp>,
            DefaultOpConversionPattern<ttnn::CosOp>,
            DefaultOpConversionPattern<ttnn::ExpOp>,
+           DefaultOpConversionPattern<ttnn::GeluOp>,
            DefaultOpConversionPattern<ttnn::LogicalNotOp>,
            DefaultOpConversionPattern<ttnn::NegOp>,
            DefaultOpConversionPattern<ttnn::ReciprocalOp>,
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 39352d1bb9..eaee949c39 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -348,6 +348,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::Cos;
   } else if constexpr (std::is_same_v<EltwiseOp, SinOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Sin;
+  } else if constexpr (std::is_same_v<EltwiseOp, GeluOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Gelu;
   } else {
     llvm_unreachable("unhandled EltwiseOp");
   }
@@ -662,6 +664,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto sinOp = dyn_cast<SinOp>(op); sinOp) {
     return createOperation(cache, createEltwiseOp(cache, sinOp), debugString);
   }
+  if (auto geluOp = dyn_cast<GeluOp>(op); geluOp) {
+    return createOperation(cache, createEltwiseOp(cache, geluOp), debugString);
+  }
   llvm_unreachable("unhandled op in emitTTNNOperation");
 }
diff --git a/runtime/lib/ttnn/operations/eltwise/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary.cpp
index a4aca2c482..54adaaeb82 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/unary.cpp
@@ -92,6 +92,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseUnaryOP(op, tensorPool, ::ttnn::cos);
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::Gelu: {
+    runEltwiseUnaryWithFastAndApproximateModeOP(op, tensorPool, ::ttnn::gelu);
+    break;
+  }
   case ::tt::target::ttnn::EltwiseOpType::LogicalNot: {
     runEltwiseUnaryOP(op, tensorPool, ::ttnn::logical_not);
     break;
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
new file mode 100644
index 0000000000..c0bda5113b
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
@@ -0,0 +1,11 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.gelu"[[C:.*]]
+    %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    return %1 : tensor<64x128xf32>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index b8f387d274..1fbe48c1ed 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -183,3 +183,11 @@ func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
   %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
   return %1 : tensor<64x128xbf16>
 }
+
+func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+  %0 = tensor.empty() : tensor<64x128xf32>
+  // CHECK: %[[C:.*]] = "ttnn.gelu"[[C:.*]]
+  %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+  return %1 : tensor<64x128xf32>
+}
\ No newline at end of file