diff --git a/lib/Conversion/PolyToStandard/BUILD b/lib/Conversion/PolyToStandard/BUILD index b27628e..762126a 100644 --- a/lib/Conversion/PolyToStandard/BUILD +++ b/lib/Conversion/PolyToStandard/BUILD @@ -31,6 +31,8 @@ cc_library( "pass_inc_gen", "//lib/Dialect/Poly", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorDialect", diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index eeaadf5..c18fba8 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -2,6 +2,8 @@ #include "lib/Dialect/Poly/PolyOps.h" #include "lib/Dialect/Poly/PolyTypes.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -32,6 +34,22 @@ class PolyToStandardTypeConverter : public TypeConverter { } }; +struct ConvertAdd : public OpConversionPattern { + ConvertAdd(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::AddIOp addOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {addOp}); + return success(); + } +}; + struct PolyToStandard : impl::PolyToStandardBase { using PolyToStandardBase::PolyToStandardBase; @@ -40,9 +58,35 @@ struct PolyToStandard : impl::PolyToStandardBase { auto *module = getOperation(); ConversionTarget target(*context); + target.addLegalDialect(); target.addIllegalDialect(); RewritePatternSet patterns(context); + PolyToStandardTypeConverter typeConverter(context); + patterns.add(typeConverter, context); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); diff --git a/tests/poly_to_standard.mlir b/tests/poly_to_standard.mlir new file mode 100644 index 0000000..94aa699 --- /dev/null +++ b/tests/poly_to_standard.mlir @@ -0,0 +1,8 @@ +// RUN: tutorial-opt --poly-to-standard %s | FileCheck %s + +// CHECK-LABEL: test_lower_add +func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { + // CHECK: arith.addi + %2 = poly.add %0, %1: !poly.poly<10> + return %2 : !poly.poly<10> +}