Added support for stablehlo.convert op
* stablehlo.convert op is used to convert the data type of a tensor.
* Introduced a new TTIR::TypecastOp to map stablehlo.convert onto; it is used instead
of the TTIR::to_layout op so that no layout information is added to the TTIR MLIR graph.
* TTIR::TypecastOp is then lowered to TTNN::to_layout, as sketched below.
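
The full lowering path, sketched as MLIR below. This is an illustrative summary pieced together from the tests in this commit, not additional committed code; attributes are elided on ttir.typecast, and the attribute, encoding, and device syntax on ttnn.to_layout is an assumption.

// StableHLO input: an element-type conversion.
%0 = stablehlo.convert %arg0 : (tensor<2x4xf32>) -> tensor<2x4xbf16>

// After --stablehlo-to-ttir-pipeline: a DPS-style ttir.typecast writing into
// an empty destination tensor; no layout encoding appears at this level.
%empty = tensor.empty() : tensor<2x4xbf16>
%1 = "ttir.typecast"(%arg0, %empty) : (tensor<2x4xf32>, tensor<2x4xbf16>) -> tensor<2x4xbf16>

// After --ttir-to-ttnn-backend-pipeline: the cast is carried by a
// ttnn.to_layout whose result type holds the target element type.
// (#layout0/#layout1, #device, and the attribute syntax are assumed.)
%2 = "ttnn.to_layout"(%arg0, %device) <{layout = #ttnn.layout<tile>}> : (tensor<2x4xf32, #layout0>, !tt.device<#device>) -> tensor<2x4xbf16, #layout1>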
mmanzoorTT committed Sep 25, 2024
1 parent 9abcba1 commit 9034e8a
Showing 5 changed files with 84 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -167,6 +167,13 @@ def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
   }];
 }
 
+def TTIR_TypecastOp: TTIR_ElementwiseUnaryOp<"typecast"> {
+  let summary = "Eltwise cast op.";
+  let description = [{
+    Eltwise cast operation.
+  }];
+}
+
 def TTIR_SqrtOp : TTIR_ElementwiseUnaryOp<"sqrt"> {
   let summary = "Eltwise square root.";
   let description = [{
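Since TTIR_TypecastOp derives from TTIR_ElementwiseUnaryOp, it uses the same destination-passing-style form as the other unary ops: one input tensor plus a pre-allocated output tensor whose element type is the cast target. A minimal textual sketch (attributes elided), mirroring simple_cast.mlir below:

%0 = tensor.empty() : tensor<64x128xbf16>
%1 = "ttir.typecast"(%arg0, %0) : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>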
9 changes: 6 additions & 3 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -328,13 +328,16 @@ void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
   patterns.add<StableHLOToTTIROpDefaultConversionPattern<
       mlir::stablehlo::AbsOp, mlir::tt::ttir::AbsOp>>(typeConverter, ctx);
   patterns.add<StableHLOToTTIROpDefaultConversionPattern<
-      mlir::stablehlo::SqrtOp, mlir::tt::ttir::SqrtOp>>(typeConverter, ctx);
-  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
-      mlir::stablehlo::RsqrtOp, mlir::tt::ttir::RsqrtOp>>(typeConverter, ctx);
+      mlir::stablehlo::ConvertOp, mlir::tt::ttir::TypecastOp>>(typeConverter,
+                                                               ctx);
   patterns.add<StableHLOToTTIROpDefaultConversionPattern<
       mlir::stablehlo::ExpOp, mlir::tt::ttir::ExpOp>>(typeConverter, ctx);
   patterns.add<StableHLOToTTIROpDefaultConversionPattern<
       mlir::stablehlo::NegOp, mlir::tt::ttir::NegOp>>(typeConverter, ctx);
+  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
+      mlir::stablehlo::RsqrtOp, mlir::tt::ttir::RsqrtOp>>(typeConverter, ctx);
+  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
+      mlir::stablehlo::SqrtOp, mlir::tt::ttir::SqrtOp>>(typeConverter, ctx);
 }
 
 void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
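Aside from registering the new ConvertOp-to-TypecastOp mapping, this hunk only reorders the existing Sqrt and Rsqrt registrations so the unary patterns stay alphabetized; the single new registration accounts for the net +3 line change.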
32 changes: 32 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -490,6 +490,37 @@ class ConstantOpConversionPattern
   }
 };
 
+class TypecastOpConversionPattern
+    : public OpConversionPattern<ttir::TypecastOp> {
+public:
+  using OpConversionPattern<ttir::TypecastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::TypecastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::TypedValue<mlir::RankedTensorType> result =
+        mlir::cast<mlir::TypedValue<mlir::RankedTensorType>>(op->getOperand(1));
+    mlir::TypedValue<mlir::RankedTensorType> input =
+        mlir::cast<mlir::TypedValue<mlir::RankedTensorType>>(op->getOperand(0));
+
+    mlir::Value device = getOrInsertDevice(rewriter, op);
+    tt::LayoutAttr ttLayoutAttr =
+        mlir::cast<tt::LayoutAttr>(result.getType().getEncoding());
+    Type elementType = ttLayoutAttr.getMemref().getElementType();
+    ttnn::Layout ttnnLayoutEnum = ttnn::Layout::RowMajor;
+    if (llvm::isa<TileType>(elementType)) {
+      ttnnLayoutEnum = ttnn::Layout::Tile;
+    }
+
+    ttnn::LayoutAttr tensorLayoutAttr =
+        ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum);
+
+    rewriter.replaceOpWithNewOp<ttnn::ToLayoutOp>(op, result.getType(), input,
+                                                  device, tensorLayoutAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 // ANCHOR: adding_an_op_matmul_op_rewriter
@@ -677,6 +708,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            EmbeddingOpConversionPattern,
            SoftmaxOpConversionPattern,
            TransposeOpConversionPattern,
+           TypecastOpConversionPattern,
            ConcatOpConversionPattern,
            ReshapeOpConversionPattern,
            SqueezeOpConversionPattern,
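The pattern keys everything off the DPS output (operand 1): it reads that tensor's tt.LayoutAttr encoding, chooses ttnn::Layout::Tile when the encoded memref element type is a TileType and RowMajor otherwise, and replaces the op with a ttnn.to_layout whose result type already carries the target element type. A hedged sketch of the rewritten IR for the f32-to-bf16 test below; the attribute, encoding, and device syntax are assumptions:

%out = "ttnn.to_layout"(%input, %device) <{layout = #ttnn.layout<tile>}> : (tensor<64x128xf32, #layout_in>, !tt.device<#device>) -> tensor<64x128xbf16, #layout_out>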
26 changes: 26 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/convert_op.mlir
@@ -0,0 +1,26 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_convert attributes {} {
  func.func public @test_convert(%arg0: tensor<2x4xf32>) -> tensor<2x4xbf16> {
    %0 = stablehlo.convert %arg0 : (tensor<2x4xf32>) -> tensor<2x4xbf16>
    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
    // CHECK: %[[C:.*]] = "ttir.typecast"
    // CHECK-SAME: (tensor<2x4xf32>, tensor<2x4xbf16>) -> tensor<2x4xbf16>
    return %0 : tensor<2x4xbf16>
  }
}

module @jit_eltwise_add attributes {} {
  func.func public @test_add(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
    %0 = stablehlo.convert %arg0 : tensor<13x21x3xf32>
    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
    // CHECK: %[[ARG1:.*]] = "ttir.typecast"[[C:.*]]
    %1 = stablehlo.convert %arg1 : tensor<13x21x3xf32>
    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
    // CHECK: %[[ARG2:.*]] = "ttir.typecast"[[C:.*]]
    %2 = stablehlo.add %0, %1 : tensor<13x21x3xf32>
    // CHECK: %[[C:.*]] = "ttir.add"(%[[ARG1]], %[[ARG2]],
    return %2 : tensor<13x21x3xf32>
  }
}
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir
@@ -0,0 +1,13 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
    %0 = tensor.empty() : tensor<64x128xbf16>
    // CHECK: {{.*}} = "ttnn.empty"{{.*}}
    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
    // CHECK: %[[C:.*]] = "ttnn.to_layout"
    // CHECK-SAME: tensor<64x128xf32,
    // CHECK-SAME: tensor<64x128xbf16,
    return %1 : tensor<64x128xbf16>
  }
}
