Skip to content

Commit

Permalink
Add support for sign op (#1086)
Browse files Browse the repository at this point in the history
* Adding sign op

* More detailed testing

* Added perf test

* Formatting changes
  • Loading branch information
ajakovljevicTT authored Nov 4, 2024
1 parent d73456b commit 25191b6
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 1 deletion.
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def TTIR_CosOp: TTIR_ElementwiseUnaryOp<"cos"> {
}];
}

def TTIR_SignOp: TTIR_ElementwiseUnaryOp<"sign"> {
let summary = "Eltwise sign operation.";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
tensor.

Example:
%a: [[3, -2, 0], [1, -4, 4]]
"ttir.sign"(%a, %out) -> %out: [[1, -1, 0], [1, -1, 1]]
}];
}

def TTIR_LogicalNotOp: TTIR_ElementwiseUnaryOp<"logical_not"> {
let summary = "Eltwise logical not op.";
let description = [{
Expand Down
12 changes: 12 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,18 @@ def TTNN_CeilOp : TTNN_ElementwiseUnaryOp<"ceil"> {
}];
}

def TTNN_SignOp: TTNN_ElementwiseUnaryOp<"sign"> {
let summary = "Eltwise sign operation.";
let description = [{
Returns the sign of the `operand` element-wise and produces a `result`
tensor.

Example:
%a: [[3, -2, 0], [1, -4, 4]]
"ttnn.sign"(%a, %out) -> %out: [[1, -1, 0], [1, -1, 1]]
}];
}

def TTNN_CosOp : TTNN_ElementwiseUnaryOp<"cos"> {
let summary = "Eltwise cosine.";
let description = [{
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ enum EltwiseOpType: uint32 {
Cos = 27,
Log = 28,
Log1p = 29,
Expm1 = 30
Expm1 = 30,
Sign = 31
}

union EltwiseOpParams {
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,8 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
mlir::stablehlo::Log1pOp, mlir::tt::ttir::Log1pOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::Expm1Op, mlir::tt::ttir::Expm1Op>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SignOp, mlir::tt::ttir::SignOp>>(typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
ElementwiseOpConversionPattern<ttir::SignOp, ttnn::SignOp>,
ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
ElementwiseOpConversionPattern<ttir::Log1pOp, ttnn::Log1pOp>,
ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::ReluOp>,
DefaultOpConversionPattern<ttnn::SqrtOp>,
DefaultOpConversionPattern<ttnn::RsqrtOp>,
DefaultOpConversionPattern<ttnn::SignOp>,
DefaultOpConversionPattern<ttnn::SigmoidOp>,
DefaultOpConversionPattern<ttnn::Log1pOp>,
DefaultOpConversionPattern<ttnn::ReciprocalOp>,
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Sqrt;
} else if constexpr (std::is_same_v<EltwiseOp, RsqrtOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Rsqrt;
} else if constexpr (std::is_same_v<EltwiseOp, SignOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Sign;
} else if constexpr (std::is_same_v<EltwiseOp, ReciprocalOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Reciprocal;
} else if constexpr (std::is_same_v<EltwiseOp, DivOp>) {
Expand Down Expand Up @@ -600,6 +602,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto rsqrtOp = dyn_cast<RsqrtOp>(op); rsqrtOp) {
return createOperation(cache, createEltwiseOp(cache, rsqrtOp), debugString);
}
if (auto signOp = dyn_cast<SignOp>(op); signOp) {
return createOperation(cache, createEltwiseOp(cache, signOp), debugString);
}
if (auto expOp = dyn_cast<ExpOp>(op); expOp) {
return createOperation(cache, createEltwiseOp(cache, expOp), debugString);
}
Expand Down
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::reciprocal);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Sign: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::sign);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Exp: {
runEltwiseUnaryWithFastAndApproximateModeOP(op, tensorPool, ::ttnn::exp);
break;
Expand Down
12 changes: 12 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/sign_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_sign attributes {} {
func.func public @test_sign(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.sign %arg0 : tensor<13x21x3xf32>
// CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
// CHECK: [[VAL1:%[0-9]+]] = "ttir.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
return %0 : tensor<13x21x3xf32>
// CHECK: return [[VAL1]] : [[TENSOR_SIZE]]
}
}
12 changes: 12 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
%1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
return %1 : tensor<64x128xf32>
// CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
}
}
14 changes: 14 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
%1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
return %1 : tensor<64x128xf32>
// CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
}
9 changes: 9 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,12 @@ func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
return %1 : tensor<64x128xf32>
// CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
}

func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
%1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
return %1 : tensor<64x128xf32>
// CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
}

0 comments on commit 25191b6

Please sign in to comment.