diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index f74341f4e..6beaa7460 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -281,6 +281,18 @@ def TTIR_LogOp: TTIR_ElementwiseUnaryOp<"log"> {
   }];
 }
 
+def TTIR_Log1pOp: TTIR_ElementwiseUnaryOp<"log1p"> {
+  let summary = "Eltwise log1p operation.";
+  let description = [{
+    Performs an element-wise natural logarithm of one plus the input, log(1 + x),
+    on the `operand` tensor and puts the result in the output tensor.
+
+    Example:
+      %a: [0.0, -0.999, 7.0, 6.38905621, 15.0]
+      "ttir.log1p"(%a, %out) -> %out: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
+  }];
+}
+
 class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
     TTIR_ElementwiseOp<mnemonic, traits> {
     let summary = "Eltwise binary op.";
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index e663dd018..79b1f8570 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -244,6 +244,18 @@ def TTNN_LogOp : TTNN_ElementwiseUnaryOp<"log"> {
   }];
 }
 
+def TTNN_Log1pOp: TTNN_ElementwiseUnaryOp<"log1p"> {
+  let summary = "Eltwise log1p operation.";
+  let description = [{
+    Performs an element-wise natural logarithm of one plus the input, log(1 + x),
+    on the `operand` tensor and puts the result in the output tensor.
+
+    Example:
+      %a: [0.0, -0.999, 7.0, 6.38905621, 15.0]
+      "ttnn.log1p"(%a, %out) -> %out: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
+  }];
+}
+
 def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> {
   let summary = "Eltwise add.";
   let description = [{
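A note on the semantics the two op descriptions above promise: `log1p(x)` is `log(1 + x)` evaluated so that precision is not lost when `x` is near zero, and the docstring example values match that contract up to float32 rounding. A standalone host-side sanity check (plain C++ with `std::log1p`, not part of this patch):

```cpp
// Reproduces the docstring example values and shows why log1p exists:
// the naive log(1 + x) loses all precision once 1 + x rounds to 1.
#include <cmath>
#include <cstdio>

int main() {
  const float inputs[] = {0.0f, -0.999f, 7.0f, 6.38905621f, 15.0f};
  for (float x : inputs) {
    std::printf("log1p(%g) = %.8f\n", x, std::log1p(x));
  }
  // For x = 1e-8f, 1.0f + x rounds to exactly 1.0f, so the naive form
  // returns 0 while log1p still returns ~1e-8.
  std::printf("naive: %.10g   log1p: %.10g\n", std::log(1.0f + 1e-8f),
              std::log1p(1e-8f));
  return 0;
}
```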
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 8e3facea2..44ae9d8c2 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -86,7 +86,8 @@ enum EltwiseOpType: uint32 {
   Ceil = 25,
   Sin = 26,
   Cos = 27,
-  Log = 28
+  Log = 28,
+  Log1p = 29,
 }
 
 union EltwiseOpParams {
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 4cb2c3efa..d94e66fe7 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -859,6 +859,8 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
   patterns.add<StableHLOToTTIROpDefaultConversionPattern<
       mlir::stablehlo::SineOp, mlir::tt::ttir::SinOp>>(typeConverter, ctx);
+  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
+      mlir::stablehlo::Log1pOp, mlir::tt::ttir::Log1pOp>>(typeConverter, ctx);
 }
 
 void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 7cefdbb79..27543b721 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -877,6 +877,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
+           ElementwiseOpConversionPattern<ttir::Log1pOp, ttnn::Log1pOp>,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 70f6dd4d5..561830099 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -622,6 +622,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
+               DefaultOpConversionPattern<ttnn::Log1pOp>,
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 1cd080cc2..0fbc52708 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -343,6 +343,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::Div;
   } else if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Sigmoid;
+  } else if constexpr (std::is_same_v<EltwiseOp, Log1pOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Log1p;
   } else if constexpr (std::is_same_v<EltwiseOp, ExpOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Exp;
@@ -606,6 +608,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createEltwiseOp(cache, sigmoidOp),
                            debugString);
   }
+  if (auto log1pOp = dyn_cast<Log1pOp>(op); log1pOp) {
+    return createOperation(cache, createEltwiseOp(cache, log1pOp), debugString);
+  }
   if (auto reciprocalOp = dyn_cast<ReciprocalOp>(op); reciprocalOp) {
     return createOperation(cache, createEltwiseOp(cache, reciprocalOp),
                            debugString);
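One remark on the `createEltwiseOp` hunk above: the serializer picks the flatbuffer enum tag for each TTNN op class at compile time via an `if constexpr` / `std::is_same_v` chain, so a missing mapping surfaces when the template is instantiated rather than as a runtime fallthrough. A minimal self-contained sketch of that technique (stand-in types, not the real tt-mlir classes):

```cpp
// Compile-time type-to-enum mapping in the style of createEltwiseOp.
// The types and enum here are illustrative stand-ins only.
#include <type_traits>

enum class EltwiseOpType { Sigmoid, Log1p, Exp };

struct SigmoidOp {};
struct Log1pOp {};
struct ExpOp {};

// Dependent false, so the static_assert only fires if the else branch
// is actually instantiated for an unmapped op type.
template <typename> inline constexpr bool always_false_v = false;

template <typename EltwiseOp>
constexpr EltwiseOpType eltwiseTypeTag() {
  if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
    return EltwiseOpType::Sigmoid;
  } else if constexpr (std::is_same_v<EltwiseOp, Log1pOp>) {
    return EltwiseOpType::Log1p;
  } else if constexpr (std::is_same_v<EltwiseOp, ExpOp>) {
    return EltwiseOpType::Exp;
  } else {
    static_assert(always_false_v<EltwiseOp>, "unsupported eltwise op");
  }
}

static_assert(eltwiseTypeTag<Log1pOp>() == EltwiseOpType::Log1p);
```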
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
index c5b5ff6e4..e5b3216bd 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
@@ -86,10 +86,12 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseUnaryOP(op, tensorPool, ::ttnn::sigmoid);
     break;
   }
+
   case ::tt::target::ttnn::EltwiseOpType::Sin: {
     runEltwiseUnaryOP(op, tensorPool, ::ttnn::sin);
     break;
   }
+
   case ::tt::target::ttnn::EltwiseOpType::Reciprocal: {
     runEltwiseUnaryOP(op, tensorPool, ::ttnn::reciprocal);
     break;
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
index da4af9c63..78b23ce0e 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
@@ -33,6 +33,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::Log1p: {
+    runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::log1p);
+    break;
+  }
   default:
     throw std::invalid_argument(
         "Unsupported Eltwise Binary Composite operation");
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
index 11231492e..d40f32ffe 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
@@ -14,6 +14,8 @@ inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
   switch (op->type()) {
   case ::tt::target::ttnn::EltwiseOpType::Cbrt:
     return true;
+  case ::tt::target::ttnn::EltwiseOpType::Log1p:
+    return true;
   default:
     return false;
   }
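Taken together, the three runtime hunks above appear to work as a routing layer: `isUnaryCompositeOp` in unary_composite.h gates whether a deserialized eltwise op runs through the primitive unary path in unary.cpp or the composite path in unary_composite.cpp, and `Log1p` takes the composite path since it dispatches to `::ttnn::log1p`. A condensed sketch of that routing (simplified enum, with hypothetical `runPrimitive`/`runComposite` standing in for the two `run()` functions):

```cpp
// Condensed sketch of the runtime routing extended by this patch.
// The enum mirrors program.fbs; the helpers are illustrative stand-ins.
#include <cstdio>

enum class EltwiseOpType { Sigmoid, Sin, Reciprocal, Cbrt, Log1p };

// Mirrors isUnaryCompositeOp in unary_composite.h: composite ops are the
// ones ttnn provides as composite (multi-primitive) implementations.
inline bool isUnaryCompositeOp(EltwiseOpType type) {
  switch (type) {
  case EltwiseOpType::Cbrt:
  case EltwiseOpType::Log1p:
    return true;
  default:
    return false;
  }
}

void runPrimitive(EltwiseOpType) { std::puts("primitive unary path"); }
void runComposite(EltwiseOpType) { std::puts("composite unary path"); }

void runUnaryEltwise(EltwiseOpType type) {
  if (isUnaryCompositeOp(type)) {
    runComposite(type); // e.g. Log1p -> ::ttnn::log1p
  } else {
    runPrimitive(type); // e.g. Sin -> ::ttnn::sin
  }
}

int main() {
  runUnaryEltwise(EltwiseOpType::Log1p); // composite unary path
  runUnaryEltwise(EltwiseOpType::Sin);   // primitive unary path
  return 0;
}
```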
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir
new file mode 100644
index 000000000..d1d44f3af
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir
@@ -0,0 +1,12 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module @jit_eltwise_log_plus_one attributes {} {
+  func.func public @test_log_plus_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+    %0 = stablehlo.log_plus_one %arg0 : tensor<13x21x3xf32>
+    // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]]
+    // CHECK: [[VAL1:%[0-9]+]] = "ttir.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]]
+    return %0 : tensor<13x21x3xf32>
+    // CHECK: return [[VAL1]] : [[TENSOR_SIZE]]
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir
new file mode 100644
index 000000000..b65aa3c21
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir
@@ -0,0 +1,12 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
+    %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
+    return %1 : tensor<64x128xf32>
+    // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir
new file mode 100644
index 000000000..2c32cc817
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir
@@ -0,0 +1,15 @@
+
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+#any_device = #tt.operand_constraint<any_device>
+#any_device_tile = #tt.operand_constraint<any_device_tile>
+
+func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  %0 = tensor.empty() : tensor<64x128xf32>
+  // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
+  %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+  // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
+  return %1 : tensor<64x128xf32>
+  // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index 5b683e276..71bafe10d 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -191,3 +191,12 @@ func.func @log(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
   %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
   return %1 : tensor<64x128xf32>
 }
+
+func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  %0 = tensor.empty() : tensor<64x128xf32>
+  // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
+  %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+  // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
+  return %1 : tensor<64x128xf32>
+  // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
+}