Skip to content

Commit

Permalink
Added conversion from stablehlo.transposeOp to ttir.transposeOp (#641)
Browse files Browse the repository at this point in the history
* Added conversion from stablehlo.transposeOp to ttir.transposeOp with assertion to limit conversion to two-dimensional inputs only. 

* Refactor reduce op conversion pattern to concrete class and renamed stablehlo tests.
  • Loading branch information
uazizTT authored Sep 9, 2024
1 parent 79d2f52 commit 97fdf8d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
53 changes: 44 additions & 9 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ class StableHLOToTTIROpDefaultConversionPattern
}
};

template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor>
class StableHLOToTTIRReduceOpConversionPattern
: public OpConversionPattern<SrcOp> {
: public OpConversionPattern<mlir::stablehlo::ReduceOp> {

using OpConversionPattern<SrcOp>::OpConversionPattern;
using OpConversionPattern<mlir::stablehlo::ReduceOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
mlir::stablehlo::ReduceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!checkBasicLegality(srcOp)) {
return failure();
Expand All @@ -76,7 +76,7 @@ class StableHLOToTTIRReduceOpConversionPattern
}

private:
bool checkBasicLegality(SrcOp &srcOp) const {
bool checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp) const {
if (!srcOp.getBody().hasOneBlock()) {
// Expecting StableHLO Reduce OP to have one block inside its body.
return false;
Expand All @@ -92,7 +92,8 @@ class StableHLOToTTIRReduceOpConversionPattern

template <typename DestOp>
LogicalResult
matchAndRewriteInternal(SrcOp &srcOp, Adaptor &adaptor,
matchAndRewriteInternal(mlir::stablehlo::ReduceOp &srcOp,
mlir::stablehlo::ReduceOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {
auto outputType =
mlir::cast<RankedTensorType>(srcOp.getResultTypes().front());
Expand All @@ -119,6 +120,34 @@ class StableHLOToTTIRReduceOpConversionPattern
}
};

class StableHLOToTTIRTransposeOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::TransposeOp> {
using OpConversionPattern<mlir::stablehlo::TransposeOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
mlir::stablehlo::TransposeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

assert(adaptor.getPermutation().size() == 2 &&
"TTIR only supports only two dimensional transposeOp.");

rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
srcOp, outputTensor.getType(), Value(adaptor.getOperand()),
Value(outputTensor), adaptor.getPermutation()[0],
adaptor.getPermutation()[1],
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -147,9 +176,14 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
void addReduceOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns
.add<StableHLOToTTIRReduceOpConversionPattern<mlir::stablehlo::ReduceOp>>(
typeConverter, ctx);
patterns.add<StableHLOToTTIRReduceOpConversionPattern>(typeConverter, ctx);
}

void addTransposeOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<StableHLOToTTIRTransposeOpConversionPattern>(typeConverter, ctx);
}

} // namespace
Expand All @@ -162,6 +196,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter);
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/tranpose_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_transpose attributes {} {
func.func public @test_transpose(%arg0: tensor<64x128xf32>) -> tensor<128x64xf32> {
%0 = stablehlo.transpose %arg0, dims = [1,0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]]
return %0 : tensor<128x64xf32>
}
}

0 comments on commit 97fdf8d

Please sign in to comment.