Skip to content

Commit

Permalink
Added GELU Op support (#1125)
Browse files Browse the repository at this point in the history
  • Loading branch information
azecevicTT authored Nov 13, 2024
1 parent cb1e6fc commit 9925638
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,13 @@ def TTIR_FloorOp: TTIR_ElementwiseUnaryOp<"floor"> {
}];
}

// GELU (Gaussian Error Linear Unit) elementwise unary activation in the TTIR
// dialect. Operands/results and traits come from the TTIR_ElementwiseUnaryOp
// base class; this def only supplies the mnemonic and docs.
def TTIR_GeluOp: TTIR_ElementwiseUnaryOp<"gelu"> {
let summary = "Eltwise GELU op.";
let description = [{
Eltwise GELU operation.
}];
}

def TTIR_IsFiniteOp: TTIR_ElementwiseUnaryOp<"isfinite"> {
let summary = "Eltwise isfinite op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ def TTNN_FloorOp: TTNN_ElementwiseUnaryOp<"floor"> {
}];
}

// GELU (Gaussian Error Linear Unit) elementwise unary activation in the TTNN
// dialect; ttir.gelu is converted to this op by the TTIR-to-TTNN patterns.
// Summary normalized to the "Eltwise <name> op." wording used by the sibling
// ops (e.g. TTNN_IsFiniteOp) and by the TTIR counterpart.
def TTNN_GeluOp: TTNN_ElementwiseUnaryOp<"gelu"> {
let summary = "Eltwise GELU op.";
let description = [{
Eltwise GELU operation.
}];
}

def TTNN_IsFiniteOp: TTNN_ElementwiseUnaryOp<"isfinite"> {
let summary = "Eltwise isfinite op.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ enum EltwiseOpType: uint32 {
IsFinite = 33,
Floor = 34,
Where = 35,
Gelu = 36,
}

union EltwiseOpParams {
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::MinimumOp, ttnn::MinimumOp>,
ElementwiseOpConversionPattern<ttir::NegOp, ttnn::NegOp>,
ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseOpConversionPattern<ttir::GeluOp, ttnn::GeluOp>,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
ElementwiseOpConversionPattern<ttir::SignOp, ttnn::SignOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::GeluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
DefaultOpConversionPattern<ttnn::RsqrtOp>,
DefaultOpConversionPattern<ttnn::SignOp>,
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Remainder;
} else if constexpr (std::is_same_v<EltwiseOp, WhereOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Where;
} else if constexpr (std::is_same_v<EltwiseOp, GeluOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Gelu;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
Expand Down Expand Up @@ -725,6 +727,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto whereOp = dyn_cast<WhereOp>(op); whereOp) {
return createOperation(cache, createEltwiseOp(cache, whereOp), debugString);
}
if (auto geluOp = dyn_cast<GeluOp>(op); geluOp) {
return createOperation(cache, createEltwiseOp(cache, geluOp), debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::floor);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Gelu: {
runEltwiseUnaryWithFastAndApproximateModeOP(op, tensorPool, ::ttnn::gelu);
break;
}
case ::tt::target::ttnn::EltwiseOpType::IsFinite: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::isfinite);
break;
Expand Down
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// Lowering test: verifies that ttir.gelu converts to ttnn.gelu (preceded by a
// ttnn.empty allocating its output buffer) under the TTIR-to-TTNN pipeline.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// The output buffer must survive lowering with the same element type/shape.
// CHECK: "ttnn.empty"
// CHECK-SAME: tensor<64x128xf32,
%0 = tensor.empty() : tensor<64x128xf32>
// Three tensor types: input operand, output operand, and the result.
// CHECK: "ttnn.gelu"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
%1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
}
17 changes: 17 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// Silicon perf test: lowers ttir.gelu through the TTNN backend pipeline using
// the real system descriptor, then serializes the result to a flatbuffer.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
// NOTE(review): #any_device_tile is not referenced in this file — presumably
// kept for parity with the other perf tests; confirm before removing.
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: "ttnn.empty"
// CHECK-SAME: tensor<64x128xf32,
%0 = tensor.empty() : tensor<64x128xf32>
// Three tensor types: input operand, output operand, and the result.
// CHECK: "ttnn.gelu"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
%1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
12 changes: 12 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,15 @@ func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> ten
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
return %3 : tensor<13x37xf32>
}

// GELU unary case: expects lowering to ttnn.gelu with a ttnn.empty output
// buffer. #any_device is declared near the top of this file (outside this hunk).
func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: "ttnn.empty"
// CHECK-SAME: tensor<64x128xf32,
%0 = tensor.empty() : tensor<64x128xf32>
// Three tensor types: input operand, output operand, and the result.
// CHECK: "ttnn.gelu"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
%1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}

0 comments on commit 9925638

Please sign in to comment.