Skip to content

Commit

Permalink
Add cbrt end to end
Browse files Browse the repository at this point in the history
  • Loading branch information
ddilbazTT committed Oct 24, 2024
1 parent 2e8bc5a commit 871f083
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
}];
}

// Elementwise cube-root op (cbrt). Produced from stablehlo.cbrt by the
// StableHLO->TTIR conversion and lowered to ttnn.cbrt by TTIRToTTNN.
def TTIR_CbrtOp: TTIR_ElementwiseUnaryOp<"cbrt"> {
    let summary = "Eltwise cubic root op.";
    let description = [{
      Eltwise cubic root operation.
    }];
}


def TTIR_TypecastOp: TTIR_ElementwiseUnaryOp<"typecast"> {
let summary = "Eltwise cast op.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
}];
}

// Elementwise cube-root op (cbrt) in the TTNN dialect. Serialized to the
// flatbuffer EltwiseOpType::Cbrt and executed at runtime via the composite
// ttnn::cbrt API.
def TTNN_CbrtOp : TTNN_ElementwiseUnaryOp<"cbrt"> {
    let summary = "Eltwise cubic root.";
    let description = [{
      Eltwise cubic root operation.
    }];
}

def TTNN_SqrtOp : TTNN_ElementwiseUnaryOp<"sqrt"> {
let summary = "Eltwise sqrt.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ enum EltwiseOpType: uint32 {
LogicalAnd = 20,
LogicalOr = 21,
LogicalNot = 22,
Cbrt = 23,
}

table EltwiseOp {
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,9 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,

patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::AbsOp, mlir::tt::ttir::AbsOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::CbrtOp, mlir::tt::ttir::CbrtOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::ConvertOp, mlir::tt::ttir::TypecastOp>>(typeConverter,
ctx);
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ToLayoutOpConversionPattern,
ElementwiseOpConversionPattern<ttir::AbsOp, ttnn::AbsOp>,
ElementwiseOpConversionPattern<ttir::AddOp, ttnn::AddOp>,
ElementwiseOpConversionPattern<ttir::CbrtOp, ttnn::CbrtOp>,
ElementwiseOpConversionPattern<ttir::LogicalAndOp, ttnn::LogicalAndOp>,
ElementwiseOpConversionPattern<ttir::LogicalOrOp, ttnn::LogicalOrOp>,
ElementwiseOpConversionPattern<ttir::LogicalNotOp, ttnn::LogicalNotOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Eltwise unary ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
DefaultOpConversionPattern<ttnn::CbrtOp>,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
DefaultOpConversionPattern<ttnn::NegOp>,
DefaultOpConversionPattern<ttnn::ReluOp>,
Expand Down
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Abs;
} else if constexpr (std::is_same_v<EltwiseOp, AddOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Add;
} else if constexpr (std::is_same_v<EltwiseOp, CbrtOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Cbrt;
} else if constexpr (std::is_same_v<EltwiseOp, LogicalAndOp>) {
type = ::tt::target::ttnn::EltwiseOpType::LogicalAnd;
} else if constexpr (std::is_same_v<EltwiseOp, LogicalNotOp>) {
Expand Down Expand Up @@ -484,6 +486,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto andOp = dyn_cast<LogicalAndOp>(op); andOp) {
return createOperation(cache, createEltwiseOp(cache, andOp), debugString);
}
if (auto cbrtOp = dyn_cast<CbrtOp>(op); cbrtOp) {
return createOperation(cache, createEltwiseOp(cache, cbrtOp), debugString);
}
if (auto notOp = dyn_cast<LogicalNotOp>(op); notOp) {
return createOperation(cache, createEltwiseOp(cache, notOp), debugString);
}
Expand Down
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/eltwise/unary/unary_composite.hpp"
#include "ttnn/operations/copy.hpp"

namespace tt::runtime::ttnn::operations::unary {
Expand Down Expand Up @@ -37,6 +38,22 @@ static void runEltwiseUnaryOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

// Runs a composite eltwise unary TTNN op (e.g. ::ttnn::cbrt) for the given
// flatbuffer EltwiseOp. Composite ops differ from simple unary ops in that
// they take the output MemoryConfig explicitly rather than an optional
// output tensor. The result is stored in the tensor pool under the op's
// output global id.
//
// Note: ttnnOp is taken by const reference instead of by value — copying a
// std::function on every dispatch is needless overhead
// (clang-tidy performance-unnecessary-value-param).
static void runEltwiseUnaryCompositeOP(
    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
    const std::function<::ttnn::Tensor(
        const ::ttnn::Tensor &, const ::tt::tt_metal::MemoryConfig &)>
        &ttnnOp) {
  // Resolve the single input tensor from the pool.
  ::ttnn::Tensor *in = nullptr;
  getEltwiseUnaryOPInputTensor(op, tensorPool, &in);

  const ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      utils::createMemoryConfig(op->out());

  ::ttnn::Tensor out = ttnnOp(*in, outputMemoryConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryWithFastAndApproximateModeOP(
const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
std::function<
Expand Down Expand Up @@ -76,6 +93,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::abs);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Cbrt: {
runEltwiseUnaryCompositeOP(op, tensorPool, ::ttnn::cbrt);
break;
}
case ::tt::target::ttnn::EltwiseOpType::LogicalNot: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::logical_not);
break;
Expand Down
11 changes: 11 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/cbrt_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
// Fix: module was named @jit_eltwise_rsqrt — a copy-paste leftover from the
// rsqrt test; renamed to match the op under test.
module @jit_eltwise_cbrt attributes {} {
  func.func public @test_cbrt(%arg0: tensor<4xf64>) -> tensor<4xf64> {
    %0 = stablehlo.cbrt %arg0 : tensor<4xf64>
    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
    // CHECK: %[[C:.*]] = "ttir.cbrt"[[C:.*]]
    return %0 : tensor<4xf64>
  }
}
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/cbrt/simple_cbrt.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  // Checks that ttir.cbrt lowers to ttnn.cbrt through the TTNN backend
  // pipeline, with the output buffer materialized via ttnn.empty.
  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
    %0 = tensor.empty() : tensor<64x128xf32>
    // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]]
    %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
    return %1 : tensor<64x128xf32>
  }
}
8 changes: 8 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,11 @@ func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens
%1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}

// Silicon smoke test: ttir.cbrt should lower to ttnn.cbrt, with the output
// buffer materialized via ttnn.empty.
func.func @cbrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
  %0 = tensor.empty() : tensor<64x128xf32>
  // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]]
  %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
  return %1 : tensor<64x128xf32>
}

0 comments on commit 871f083

Please sign in to comment.