diff --git a/lib/Conversion/PolyToStandard/BUILD b/lib/Conversion/PolyToStandard/BUILD index b27628e..762126a 100644 --- a/lib/Conversion/PolyToStandard/BUILD +++ b/lib/Conversion/PolyToStandard/BUILD @@ -31,6 +31,8 @@ cc_library( "pass_inc_gen", "//lib/Dialect/Poly", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorDialect", diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index 8f81f13..9444050 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -2,6 +2,8 @@ #include "lib/Dialect/Poly/PolyOps.h" #include "lib/Dialect/Poly/PolyTypes.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -63,6 +65,29 @@ struct PolyToStandard : impl::PolyToStandardBase { PolyToStandardTypeConverter typeConverter(context); patterns.add(typeConverter, context); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); }