diff --git a/docs/src/adding-an-op.md b/docs/src/adding-an-op.md
index 9eea62adf..9108abe81 100644
--- a/docs/src/adding-an-op.md
+++ b/docs/src/adding-an-op.md
@@ -147,6 +147,7 @@ Invoked as part of the rewrite set:
 MatmulOpConversionPattern
 ```
 
+### Note: We also need to add this op to the C++ emitter, `lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp`; see `populateTTNNToEmitCPatterns(...)`.
 
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index f9f7deb25..eff965116 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -252,6 +252,13 @@ def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
   }];
 }
 
+def TTIR_SigmoidOp : TTIR_ElementwiseUnaryOp<"sigmoid"> {
+  let summary = "Eltwise sigmoid.";
+  let description = [{
+    Eltwise sigmoid operation.
+  }];
+}
+
 class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> : TTIR_DPSOp<mnemonic, traits> {
   let summary = "Reduction op.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index a39054c96..95f9bc3d9 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -104,6 +104,13 @@ def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> {
   }];
 }
 
+def TTNN_SigmoidOp : TTNN_ElementwiseUnaryOp<"sigmoid"> {
+  let summary = "Eltwise sigmoid.";
+  let description = [{
+    Eltwise sigmoid operation.
+  }];
+}
+
 def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> {
   let summary = "Eltwise add.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index f77e8aeda..65e125ebc 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -36,7 +36,8 @@ enum EltwiseOpType: uint32 {
   Relu = 3,
   GreaterEqual = 4,
   Sqrt = 5,
-  Div = 6
+  Div = 6,
+  Sigmoid = 7,
 }
 
 table EltwiseOp {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 0292c79f0..3e035faa2 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -174,6 +174,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ElementwiseBinaryOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
            ElementwiseBinaryOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
            ElementwiseBinaryOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
+           ElementwiseBinaryOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
            ElementwiseBinaryOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
            ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
            ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 317a82174..87ca7603d 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -171,7 +171,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add<DefaultOpConversionPattern<ttnn::ReluOp>>(typeConverter, ctx);
   patterns.add<DefaultOpConversionPattern<ttnn::SqrtOp>>(typeConverter, ctx);
-  patterns.add<DefaultOpConversionPattern<ttnn::SqrtOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::SigmoidOp>>(typeConverter, ctx);
 
   // Eltwise binary ops
   //
@@ -197,6 +197,9 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   //
   patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
   patterns.add<DefaultOpConversionPattern<ttnn::MeanOp>>(typeConverter, ctx);
+
+  // Other ops
+  patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>>(typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index a351e0c82..cdfb01f63 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -139,6 +139,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     type = ::tt::target::ttnn::EltwiseOpType::Sqrt;
   } else if constexpr (std::is_same_v<EltwiseOp, DivOp>) {
     type = ::tt::target::ttnn::EltwiseOpType::Div;
+  } else if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Sigmoid;
   } else {
     llvm_unreachable("unhandled EltwiseOp");
   }
@@ -242,6 +244,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto sqrtOp = dyn_cast<SqrtOp>(op); sqrtOp) {
     return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString);
   }
+  if (auto sigmoidOp = dyn_cast<SigmoidOp>(op); sigmoidOp) {
+    return createOperation(cache, createEltwiseOp(cache, sigmoidOp),
+                           debugString);
+  }
   if (auto divOp = dyn_cast<DivOp>(op); divOp) {
     return createOperation(cache, createEltwiseOp(cache, divOp), debugString);
   }
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 9fa6f195a..189c3fb25 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -256,6 +256,13 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
     liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::Sigmoid: {
+    assert(op->ins()->size() == 1 && "Unsupported number of inputs");
+    ::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
+    tensorPool.push_back(::ttnn::sigmoid(in));
+    liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
+    break;
+  }
   }
 }
@@ -296,7 +303,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
 }
 
 static void
-run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
+run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device,
     std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
     std::list<::ttnn::Tensor> &tensorPool) {
   ::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
@@ -307,7 +314,7 @@ run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
 }
 
 static void
-run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device,
+run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::Device &device,
     std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
     std::list<::ttnn::Tensor> &tensorPool) {
   ::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
new file mode 100644
index 000000000..7b99e6db1
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
@@ -0,0 +1,13 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module attributes {} {
+  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
+    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]]
+    %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    // CHECK: "ttnn.close_device"[[C:.*]]
+    return %1 : tensor<64x128xf32>
+  }
+}
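The doc note in the first hunk flags the EmitC emitter as an easy step to miss when adding an op. For reference, the registration in `populateTTNNToEmitCPatterns(...)` is a one-liner of the same shape as the `SigmoidOp` lines in this diff. Below is a minimal sketch, assuming the existing `DefaultOpConversionPattern` helper from `TTNNToEmitC.cpp`; `ttnn::MyNewOp` and the wrapper function are hypothetical names, not part of the codebase:

```cpp
// Sketch only: registering a hypothetical ttnn::MyNewOp with the EmitC
// emitter, mirroring the SigmoidOp registration added in this diff.
// In-tree, this line would live inside populateTTNNToEmitCPatterns(...).
void registerMyNewOpWithEmitC(mlir::TypeConverter &typeConverter,
                              mlir::RewritePatternSet &patterns,
                              mlir::MLIRContext *ctx) {
  // DefaultOpConversionPattern lowers the TTNN op to a plain C++ call
  // through the EmitC dialect.
  patterns.add<DefaultOpConversionPattern<ttnn::MyNewOp>>(typeConverter, ctx);
}
```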