Skip to content

Commit

Permalink
Add empty lower add pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 19, 2023
1 parent 35db35a commit 282dcb1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lib/Conversion/PolyToStandard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ cc_library(
"pass_inc_gen",
"//lib/Dialect/Poly",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
Expand Down
44 changes: 44 additions & 0 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include "lib/Dialect/Poly/PolyOps.h"
#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {
Expand Down Expand Up @@ -32,6 +34,22 @@ class PolyToStandardTypeConverter : public TypeConverter {
}
};

struct ConvertAdd : public OpConversionPattern<AddOp> {
ConvertAdd(mlir::MLIRContext *context)
: OpConversionPattern<AddOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
arith::AddIOp addOp = rewriter.create<arith::AddIOp>(
op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op.getOperation(), {addOp});
return success();
}
};

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
using PolyToStandardBase::PolyToStandardBase;

Expand All @@ -40,9 +58,35 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
auto *module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalDialect<PolyDialect>();

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd>(typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});

populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::ReturnOp>(
[&](func::ReturnOp op) { return typeConverter.isLegal(op); });

populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return typeConverter.isLegal(op); });

populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
isLegalForBranchOpInterfaceTypeConversionPattern(op,
typeConverter) ||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
});

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
Expand Down
8 changes: 8 additions & 0 deletions tests/poly_to_standard.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: tutorial-opt --poly-to-standard %s | FileCheck %s

// CHECK-LABEL: test_lower_add
func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> {
// CHECK: arith.addi
%2 = poly.add %0, %1: !poly.poly<10>
return %2 : !poly.poly<10>
}

0 comments on commit 282dcb1

Please sign in to comment.