Skip to content

Commit

Permalink
Add sigmoid op (#469)
Browse files Browse the repository at this point in the history
  • Loading branch information
svuckovicTT authored Aug 23, 2024
1 parent cfff32f commit 855d49e
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/src/adding-an-op.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Invoked as part of the rewrite set:
MatmulOpConversionPattern
```

### Note:
We also need to add this op to the C++ emitter,
`lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp` see
`populateTTNNToEmitCPatterns(...)`.
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
}];
}

def TTIR_SigmoidOp: TTIR_ElementwiseUnaryOp<"sigmoid"> {
let summary = "Eltwise sigmoid.";
let description = [{
Eltwise sigmoid operation.
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> : TTIR_DPSOp<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu"> {
}];
}

def TTNN_SigmoidOp : TTNN_ElementwiseUnaryOp<"sigmoid"> {
let summary = "Eltwise sigmoid.";
let description = [{
Eltwise sigmoid operation.
}];
}

def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> {
let summary = "Eltwise add.";
let description = [{
Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ enum EltwiseOpType: uint32 {
Relu = 3,
GreaterEqual = 4,
Sqrt = 5,
Div = 6
Div = 6,
Sigmoid = 7,
}

table EltwiseOp {
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseBinaryOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
ElementwiseBinaryOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ElementwiseBinaryOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseBinaryOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
ElementwiseBinaryOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
Expand Down
5 changes: 4 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::ReluOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SqrtOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::SigmoidOp>>(typeConverter, ctx);

// Eltwise binary ops
//
Expand All @@ -197,6 +197,9 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MeanOp>>(typeConverter, ctx);

// Other ops
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>>(typeConverter, ctx);
}

} // namespace mlir::tt
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Sqrt;
} else if constexpr (std::is_same_v<EltwiseOp, DivOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Div;
} else if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Sigmoid;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
Expand Down Expand Up @@ -242,6 +244,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto sqrtOp = dyn_cast<SqrtOp>(op); sqrtOp) {
return createOperation(cache, createEltwiseOp(cache, sqrtOp), debugString);
}
if (auto sigmoidOp = dyn_cast<SigmoidOp>(op); sigmoidOp) {
return createOperation(cache, createEltwiseOp(cache, sigmoidOp),
debugString);
}
if (auto divOp = dyn_cast<DivOp>(op); divOp) {
return createOperation(cache, createEltwiseOp(cache, divOp), debugString);
}
Expand Down
11 changes: 9 additions & 2 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::EltwiseOpType::Sigmoid: {
assert(op->ins()->size() == 1 && "Unsupported number of inputs");
::ttnn::Tensor &in = *liveTensors.at(op->ins()->Get(0)->global_id());
tensorPool.push_back(::ttnn::sigmoid(in));
liveTensors.insert_or_assign(op->out()->global_id(), &tensorPool.back());
break;
}
}
}

Expand Down Expand Up @@ -296,7 +303,7 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
}

static void
run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
Expand All @@ -307,7 +314,7 @@ run(::tt::target::ttnn::SoftmaxOp const *op, ::ttnn::device::Device &device,
}

static void
run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device,
run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
Expand Down
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]]
%1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<64x128xf32>
}
}

0 comments on commit 855d49e

Please sign in to comment.