From b39cc58104336b7f10b1d83d2c869a999b2ecfa1 Mon Sep 17 00:00:00 2001 From: Usman Aziz Date: Thu, 12 Sep 2024 15:43:33 -0400 Subject: [PATCH] Add an preliminary conversion for stablehlo.dot_general op (#656) * Add conversion for stablehlo.dot_general to ttir.matmul. * Improve error checking and updated all legality checks for stablehlo conversions. --- .../StableHLOToTTIRPatterns.cpp | 119 ++++++++++++++++-- .../StableHLOToTTIR/dot_general_op.mlir | 10 ++ 2 files changed, 119 insertions(+), 10 deletions(-) create mode 100644 test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 4a9ce5ea9..cb9dc237b 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -57,8 +57,9 @@ class StableHLOToTTIRReduceOpConversionPattern matchAndRewrite(mlir::stablehlo::ReduceOp srcOp, mlir::stablehlo::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!checkBasicLegality(srcOp)) { - return failure(); + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (not err.succeeded()) { + return err; } const mlir::Operation &innerOp = srcOp.getBody().front().front(); @@ -76,18 +77,22 @@ class StableHLOToTTIRReduceOpConversionPattern } private: - bool checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp) const { + LogicalResult checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp, + mlir::stablehlo::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { if (!srcOp.getBody().hasOneBlock()) { - // Expecting StableHLO Reduce OP to have one block inside its body. - return false; + return rewriter.notifyMatchFailure( + srcOp, + "Expecting StableHLO Reduce OP to have one block inside its body."); } if (srcOp.getBody().front().empty()) { - // Expecting StableHLO Reduce OP to have a body operation defined. - return false; + return rewriter.notifyMatchFailure( + srcOp, + "Expecting StableHLO Reduce OP to have a body operation defined."); } - return true; + return success(); } template @@ -133,8 +138,10 @@ class StableHLOToTTIRTransposeOpConversionPattern tensor::EmptyOp outputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - assert(adaptor.getPermutation().size() == 2 && - "TTIR only supports only two dimensional transposeOp."); + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (not err.succeeded()) { + return err; + } rewriter.replaceOpWithNewOp( srcOp, outputTensor.getType(), Value(adaptor.getOperand()), @@ -146,6 +153,90 @@ class StableHLOToTTIRTransposeOpConversionPattern OperandConstraint::AnyDeviceTile)))); return success(); } + + LogicalResult + checkBasicLegality(mlir::stablehlo::TransposeOp &srcOp, + mlir::stablehlo::TransposeOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter) const { + + if (adaptor.getPermutation().size() != 2) { + return rewriter.notifyMatchFailure( + srcOp, "TTIR supports only two dimensional transposeOp."); + } + + return success(); + } +}; + +class StableHLOToTTIRDotGeneralOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(mlir::stablehlo::DotGeneralOp srcOp, + mlir::stablehlo::DotGeneralOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto outputType = mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + + // This is a basic version that can only work for cases that can be directly + // converted to matmul. The op should be extended as other ops such as + // ttir.permute and ttir.broadcast_in_dim become available. + + LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter); + if (not err.succeeded()) { + return err; + } + + rewriter.replaceOpWithNewOp( + srcOp, outputTensor.getType(), adaptor.getLhs(), adaptor.getRhs(), + Value(outputTensor), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } + +private: + LogicalResult + checkBasicLegality(mlir::stablehlo::DotGeneralOp &srcOp, + mlir::stablehlo::DotGeneralOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter) const { + + ::mlir::stablehlo::DotDimensionNumbersAttr dimensions = + adaptor.getDotDimensionNumbers(); + + if (dimensions.getLhsContractingDimensions().empty() || + dimensions.getRhsContractingDimensions().empty()) { + return rewriter.notifyMatchFailure(srcOp, + "Contracting dimension is missing."); + } + + if (dimensions.getLhsContractingDimensions()[0] != 1) { + return rewriter.notifyMatchFailure( + srcOp, "Only non-transposed matmul is currently supported in TTIR."); + } + + if (dimensions.getRhsContractingDimensions()[0] != 0) { + return rewriter.notifyMatchFailure( + srcOp, "Only non-transposed matmul is currently supported in TTIR."); + } + + if (not dimensions.getLhsBatchingDimensions().empty()) { + return rewriter.notifyMatchFailure( + srcOp, "Only non-transposed matmul is currently supported in TTIR."); + } + + if (not dimensions.getRhsBatchingDimensions().empty()) { + return rewriter.notifyMatchFailure( + srcOp, "Only non-transposed matmul is currently supported in TTIR."); + } + + return success(); + } }; void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, @@ -188,6 +279,13 @@ void addTransposeOpsConversionPatterns(MLIRContext *ctx, patterns.add(typeConverter, ctx); } +void addMatmulOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, + ctx); +} + } // namespace namespace mlir::tt { @@ -199,6 +297,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx, addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); addReduceOpsConversionPatterns(ctx, patterns, typeConverter); addTransposeOpsConversionPatterns(ctx, patterns, typeConverter); + addMatmulOpsConversionPatterns(ctx, patterns, typeConverter); } } // namespace mlir::tt diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir new file mode 100644 index 000000000..da8bf8ee3 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir @@ -0,0 +1,10 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s +module @jit_dot_general attributes {} { + func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32> + // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] + // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] + return %0 : tensor<16x8xf32> + } +}