Skip to content

Commit

Permalink
Add support for typecast op
Browse files Browse the repository at this point in the history
* Add end-to-end support for typecast op
* Add stablehlo.convert conversion to typecast op
* Add required test cases
  • Loading branch information
mmanzoorTT committed Sep 26, 2024
1 parent 5856bb6 commit 641ef7f
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
}];
}

// Elementwise type-cast op: each element of the input tensor is cast to the
// element type of the result tensor (e.g. f32 -> bf16). The target type is
// carried entirely by the result type, so no extra attribute is needed.
// Wording kept consistent with the TTNN-dialect counterpart ("typecast").
def TTIR_TypecastOp : TTIR_ElementwiseUnaryOp<"typecast"> {
    let summary = "Eltwise typecast op.";
    let description = [{
      Eltwise typecast operation. Casts every element of the input tensor to
      the element type of the output tensor.
    }];
}

def TTIR_SqrtOp : TTIR_ElementwiseUnaryOp<"sqrt"> {
let summary = "Eltwise square root.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def TTNN_SigmoidOp : TTNN_ElementwiseUnaryOp<"sigmoid"> {
}];
}

// Elementwise type-cast op for the TTNN dialect: casts each element of the
// input tensor to the result tensor's element type. Serialized to the
// flatbuffer target as EltwiseOpType::Typecast and executed by the runtime's
// dedicated typecast path (see runtime eltwise/unary.cpp).
def TTNN_TypecastOp : TTNN_ElementwiseUnaryOp<"typecast"> {
let summary = "Eltwise typecast.";
let description = [{
Eltwise typecast operation.
}];
}

def TTNN_ExpOp : TTNN_ElementwiseUnaryOp<"exp"> {
let summary = "Eltwise exponential.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum EltwiseOpType: uint32 {
Abs = 11,
Neg = 12,
Rsqrt = 13,
Typecast = 14,
}

table EltwiseOp {
Expand Down
9 changes: 6 additions & 3 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,16 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::AbsOp, mlir::tt::ttir::AbsOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SqrtOp, mlir::tt::ttir::SqrtOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::RsqrtOp, mlir::tt::ttir::RsqrtOp>>(typeConverter, ctx);
mlir::stablehlo::ConvertOp, mlir::tt::ttir::TypecastOp>>(typeConverter,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::ExpOp, mlir::tt::ttir::ExpOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::NegOp, mlir::tt::ttir::NegOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::RsqrtOp, mlir::tt::ttir::RsqrtOp>>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SqrtOp, mlir::tt::ttir::SqrtOp>>(typeConverter, ctx);
}

void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::SqrtOp, ttnn::SqrtOp>,
ElementwiseOpConversionPattern<ttir::RsqrtOp, ttnn::RsqrtOp>,
ElementwiseOpConversionPattern<ttir::SigmoidOp, ttnn::SigmoidOp>,
ElementwiseOpConversionPattern<ttir::TypecastOp, ttnn::TypecastOp>,
ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
ElementwiseOpConversionPattern<ttir::ExpOp, ttnn::ExpOp>,
ElementwiseOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::SqrtOp>,
DefaultOpConversionPattern<ttnn::RsqrtOp>,
DefaultOpConversionPattern<ttnn::SigmoidOp>,
DefaultOpConversionPattern<ttnn::TypecastOp>,
DefaultOpConversionPattern<ttnn::ReciprocalOp>,
DefaultOpConversionPattern<ttnn::ExpOp>>(typeConverter, ctx);

Expand Down
6 changes: 6 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Div;
} else if constexpr (std::is_same_v<EltwiseOp, SigmoidOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Sigmoid;
} else if constexpr (std::is_same_v<EltwiseOp, TypecastOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Typecast;
} else if constexpr (std::is_same_v<EltwiseOp, ExpOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Exp;
} else {
Expand Down Expand Up @@ -501,6 +503,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createTransposeOp(cache, transposeOp),
debugString);
}
if (auto typecastOp = dyn_cast<TypecastOp>(op); typecastOp) {
return createOperation(cache, createEltwiseOp(cache, typecastOp),
debugString);
}
if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
return createOperation(cache, createOp(cache, conv2dOp), debugString);
}
Expand Down
18 changes: 18 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "unary.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttnn/operations/copy.hpp"

namespace tt::runtime::ttnn::operations::unary {

Expand Down Expand Up @@ -52,6 +53,19 @@ static void runEltwiseUnaryWithFastAndApproximateModeOP(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

// Executes a serialized EltwiseOpType::Typecast op: casts the input tensor to
// the element type recorded on the op's output tensor descriptor.
//
// Both the destination dtype and the memory config are derived from the
// output descriptor, so the cast target is whatever element type the compiler
// assigned to the result tensor. The result replaces any existing entry for
// the output's global id in the tensor pool.
static void runTypecastOp(const ::tt::target::ttnn::EltwiseOp *op,
                          ProgramTensorPool &tensorPool) {
  const auto *outRef = op->out();
  const auto outId = outRef->global_id();

  // Resolve the (single) unary input from the pool.
  ::ttnn::Tensor *inputTensor = nullptr;
  getEltwiseUnaryOPInputTensor(op, tensorPool, &inputTensor);

  // Target dtype comes from the pre-registered output tensor in the pool.
  const DataType targetType = tensorPool.at(outId).get_dtype();
  const ::tt::tt_metal::MemoryConfig memConfig =
      utils::createMemoryConfig(outRef);

  ::ttnn::Tensor result = ::ttnn::operations::copy::Typecast().invoke(
      *inputTensor, targetType, memConfig);
  tensorPool.insert_or_assign(outId, result);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
assert(isUnaryOp(op) && "Expected binary operation");
ProgramTensorPool &tensorPool = context.getTensorPool();
Expand All @@ -76,6 +90,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::sigmoid);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Typecast: {
runTypecastOp(op, tensorPool);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Reciprocal: {
runEltwiseUnaryOP(op, tensorPool, ::ttnn::reciprocal);
break;
Expand Down
26 changes: 26 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/convert_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// Verifies that a dtype-changing stablehlo.convert (f32 -> bf16) lowers to a
// DPS-style ttir.typecast: a tensor.empty for the bf16 destination followed by
// the typecast op consuming (input, dest).
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_convert attributes {} {
func.func public @test_convert(%arg0: tensor<2x4xf32>) -> tensor<2x4xbf16> {
%0 = stablehlo.convert %arg0 : (tensor<2x4xf32>) -> tensor<2x4xbf16>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.typecast"
// CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xbf16>) -> tensor<2x4xbf16>
return %0 : tensor<2x4xbf16>
}
}

// Verifies that identity converts (same element type on input and result) are
// still lowered to ttir.typecast ops, and that a downstream stablehlo.add
// consumes the typecast results rather than the original arguments.
module @jit_eltwise_add attributes {} {
func.func public @test_add(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = stablehlo.convert %arg0 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[ARG1:.*]] = "ttir.typecast"[[C:.*]]
%1 = stablehlo.convert %arg1 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[ARG2:.*]] = "ttir.typecast"[[C:.*]]
%2 = stablehlo.add %0, %1 : tensor<13x21x3xf32>
// CHECK: %[[C:.*]] = "ttir.add"(%[[ARG1]], %[[ARG2]],
return %2 : tensor<13x21x3xf32>
}
}
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
// Verifies TTIR-to-TTNN lowering of ttir.typecast: the DPS destination
// (tensor.empty) becomes ttnn.empty, and the cast itself becomes
// ttnn.typecast with f32 input and bf16 output.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: {{.*}} = "ttnn.empty"{{.*}}
%1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.typecast"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xbf16,
return %1 : tensor<64x128xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ func.func @softmax(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> {
return %3 : tensor<512x1024xbf16>
}

// Silicon end-to-end check for ttir.typecast (f32 -> bf16): expects lowering
// to ttnn.empty for the destination plus ttnn.typecast for the cast itself.
// Relies on the file-level #any_device operand constraint defined above.
func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.typecast"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xbf16,
%1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
return %1 : tensor<64x128xbf16>
}

func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
Expand Down

0 comments on commit 641ef7f

Please sign in to comment.