Skip to content

Commit

Permalink
Add an preliminary conversion for stablehlo.dot_general op (#656)
Browse files Browse the repository at this point in the history
* Add conversion for stablehlo.dot_general to ttir.matmul.

* Improve error checking and updated all legality checks for stablehlo conversions.
  • Loading branch information
uazizTT authored Sep 12, 2024
1 parent 82eae0c commit b39cc58
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 10 deletions.
119 changes: 109 additions & 10 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ class StableHLOToTTIRReduceOpConversionPattern
matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
mlir::stablehlo::ReduceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!checkBasicLegality(srcOp)) {
return failure();
LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
if (not err.succeeded()) {
return err;
}

const mlir::Operation &innerOp = srcOp.getBody().front().front();
Expand All @@ -76,18 +77,22 @@ class StableHLOToTTIRReduceOpConversionPattern
}

private:
bool checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp) const {
LogicalResult checkBasicLegality(mlir::stablehlo::ReduceOp &srcOp,
mlir::stablehlo::ReduceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!srcOp.getBody().hasOneBlock()) {
// Expecting StableHLO Reduce OP to have one block inside its body.
return false;
return rewriter.notifyMatchFailure(
srcOp,
"Expecting StableHLO Reduce OP to have one block inside its body.");
}

if (srcOp.getBody().front().empty()) {
// Expecting StableHLO Reduce OP to have a body operation defined.
return false;
return rewriter.notifyMatchFailure(
srcOp,
"Expecting StableHLO Reduce OP to have a body operation defined.");
}

return true;
return success();
}

template <typename DestOp>
Expand Down Expand Up @@ -133,8 +138,10 @@ class StableHLOToTTIRTransposeOpConversionPattern
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

assert(adaptor.getPermutation().size() == 2 &&
"TTIR only supports only two dimensional transposeOp.");
LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
if (not err.succeeded()) {
return err;
}

rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
srcOp, outputTensor.getType(), Value(adaptor.getOperand()),
Expand All @@ -146,6 +153,90 @@ class StableHLOToTTIRTransposeOpConversionPattern
OperandConstraint::AnyDeviceTile))));
return success();
}

LogicalResult
checkBasicLegality(mlir::stablehlo::TransposeOp &srcOp,
mlir::stablehlo::TransposeOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {

if (adaptor.getPermutation().size() != 2) {
return rewriter.notifyMatchFailure(
srcOp, "TTIR supports only two dimensional transposeOp.");
}

return success();
}
};

class StableHLOToTTIRDotGeneralOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::DotGeneralOp> {
using OpConversionPattern<mlir::stablehlo::DotGeneralOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::DotGeneralOp srcOp,
mlir::stablehlo::DotGeneralOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto outputType = mlir::cast<RankedTensorType>(srcOp.getResult().getType());
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

// This is a basic version that can only work for cases that can be directly
// converted to matmul. The op should be extended as other ops such as
// ttir.permute and ttir.broadcast_in_dim become available.

LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
if (not err.succeeded()) {
return err;
}

rewriter.replaceOpWithNewOp<mlir::tt::ttir::MatmulOp>(
srcOp, outputTensor.getType(), adaptor.getLhs(), adaptor.getRhs(),
Value(outputTensor),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}

private:
LogicalResult
checkBasicLegality(mlir::stablehlo::DotGeneralOp &srcOp,
mlir::stablehlo::DotGeneralOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {

::mlir::stablehlo::DotDimensionNumbersAttr dimensions =
adaptor.getDotDimensionNumbers();

if (dimensions.getLhsContractingDimensions().empty() ||
dimensions.getRhsContractingDimensions().empty()) {
return rewriter.notifyMatchFailure(srcOp,
"Contracting dimension is missing.");
}

if (dimensions.getLhsContractingDimensions()[0] != 1) {
return rewriter.notifyMatchFailure(
srcOp, "Only non-transposed matmul is currently supported in TTIR.");
}

if (dimensions.getRhsContractingDimensions()[0] != 0) {
return rewriter.notifyMatchFailure(
srcOp, "Only non-transposed matmul is currently supported in TTIR.");
}

if (not dimensions.getLhsBatchingDimensions().empty()) {
return rewriter.notifyMatchFailure(
srcOp, "Only non-transposed matmul is currently supported in TTIR.");
}

if (not dimensions.getRhsBatchingDimensions().empty()) {
return rewriter.notifyMatchFailure(
srcOp, "Only non-transposed matmul is currently supported in TTIR.");
}

return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
Expand Down Expand Up @@ -188,6 +279,13 @@ void addTransposeOpsConversionPatterns(MLIRContext *ctx,
patterns.add<StableHLOToTTIRTransposeOpConversionPattern>(typeConverter, ctx);
}

void addMatmulOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRDotGeneralOpConversionPattern>(typeConverter,
ctx);
}

} // namespace

namespace mlir::tt {
Expand All @@ -199,6 +297,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter);
addReduceOpsConversionPatterns(ctx, patterns, typeConverter);
addTransposeOpsConversionPatterns(ctx, patterns, typeConverter);
addMatmulOpsConversionPatterns(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
10 changes: 10 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/dot_general_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_dot_general attributes {} {
func.func public @test_dot_general(%arg0 : tensor<16x32xf32>, %arg1 : tensor<32x8xf32>) -> tensor<16x8xf32> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<16x32xf32>, tensor<32x8xf32>) -> tensor<16x8xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]]
return %0 : tensor<16x8xf32>
}
}

0 comments on commit b39cc58

Please sign in to comment.