diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 2bba60220..60304efe7 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -47,15 +47,15 @@ class StableHLOToTTIROpDefaultConversionPattern
   }
 };
 
-template <typename SrcOp, typename DestOp, typename Adaptor = typename SrcOp::Adaptor>
 class StableHLOToTTIRReduceOpConversionPattern
-    : public OpConversionPattern<SrcOp> {
+    : public OpConversionPattern<mlir::stablehlo::ReduceOp> {
 
-  using OpConversionPattern<SrcOp>::OpConversionPattern;
+  using OpConversionPattern<mlir::stablehlo::ReduceOp>::OpConversionPattern;
 
 public:
   LogicalResult
-  matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
+  matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
+                  mlir::stablehlo::ReduceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!checkBasicLegality(srcOp)) {
       return failure();
@@ -76,7 +76,7 @@ class StableHLOToTTIRReduceOpConversionPattern
   }
 
 private:
-  bool checkBasicLegality(SrcOp &srcOp) const {
+  bool checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp) const {
     if (!srcOp.getBody().hasOneBlock()) {
       // Expecting StableHLO Reduce OP to have one block inside its body.
       return false;
@@ -92,7 +92,8 @@ class StableHLOToTTIRReduceOpConversionPattern
 
   template <typename DestOp>
   LogicalResult
-  matchAndRewriteInternal(SrcOp &srcOp, Adaptor &adaptor,
+  matchAndRewriteInternal(mlir::stablehlo::ReduceOp &srcOp,
+                          mlir::stablehlo::ReduceOp::Adaptor &adaptor,
                           ConversionPatternRewriter &rewriter) const {
     auto outputType =
         mlir::cast<RankedTensorType>(srcOp.getResultTypes().front());
@@ -119,6 +120,34 @@
   }
 };
 
+class StableHLOToTTIRTransposeOpConversionPattern
+    : public OpConversionPattern<mlir::stablehlo::TransposeOp> {
+  using OpConversionPattern<mlir::stablehlo::TransposeOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
+                  mlir::stablehlo::TransposeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
+    assert(adaptor.getPermutation().size() == 2 &&
+           "TTIR supports only two dimensional transposeOp.");
+
+    rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
+        srcOp, outputTensor.getType(), Value(adaptor.getOperand()),
+        Value(outputTensor), adaptor.getPermutation()[0],
+        adaptor.getPermutation()[1],
+        rewriter.getArrayAttr(
+            SmallVector<Attribute>(adaptor.getOperands().size() + 1,
+                                   rewriter.getAttr<OperandConstraintAttr>(
+                                       OperandConstraint::AnyDeviceTile))));
+    return success();
+  }
+};
+
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
@@ -147,9 +176,14 @@
 void addReduceOpsConversionPatterns(MLIRContext *ctx,
                                     RewritePatternSet &patterns,
                                     TypeConverter &typeConverter) {
-  patterns
-      .add<StableHLOToTTIRReduceOpConversionPattern<mlir::stablehlo::ReduceOp, mlir::tt::ttir::SumOp>>(
-          typeConverter, ctx);
+  patterns.add<StableHLOToTTIRReduceOpConversionPattern>(typeConverter, ctx);
+}
+
+void addTransposeOpsConversionPatterns(MLIRContext *ctx,
+                                       RewritePatternSet &patterns,
+                                       TypeConverter &typeConverter) {
+
+  patterns.add<StableHLOToTTIRTransposeOpConversionPattern>(typeConverter, ctx);
 }
 
 } // namespace
@@ -162,6 +196,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter);
   addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
   addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
+  addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
 }
 
 } // namespace mlir::tt
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/transpose_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/transpose_op.mlir
new file mode 100644
index 000000000..398746dda
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/transpose_op.mlir
@@ -0,0 +1,10 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_transpose attributes {} {
+  func.func public @test_transpose(%arg0: tensor<64x128xf32>) -> tensor<128x64xf32> {
+    %0 = stablehlo.transpose %arg0, dims = [1,0] : (tensor<64x128xf32>) -> tensor<128x64xf32>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]]
+    return %0 : tensor<128x64xf32>
+  }
+}