Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dialect Conversion #20

Merged
merged 15 commits into from
Oct 21, 2023
39 changes: 39 additions & 0 deletions lib/Conversion/PolyToStandard/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_visibility = ["//visibility:public"],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=PolyToStandard",
],
"PolyToStandard.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolyToStandard.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

cc_library(
name = "PolyToStandard",
srcs = ["PolyToStandard.cpp"],
hdrs = ["PolyToStandard.h"],
deps = [
"pass_inc_gen",
"//lib/Dialect/Poly",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
72 changes: 72 additions & 0 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "lib/Conversion/PolyToStandard/PolyToStandard.h"

#include "lib/Dialect/Poly/PolyOps.h"
#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir {
namespace tutorial {
namespace poly {

#define GEN_PASS_DEF_POLYTOSTANDARD
#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc"

class PolyToStandardTypeConverter : public TypeConverter {
public:
PolyToStandardTypeConverter(MLIRContext *ctx) {
addConversion([](Type type) { return type; });
addConversion([ctx](PolynomialType type) -> Type {
int degreeBound = type.getDegreeBound();
IntegerType elementTy =
IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
return RankedTensorType::get({degreeBound}, elementTy);
});

// We don't include any custom materialization hooks because this lowering
// is all done in a single pass. The dialect conversion framework works by
// resolving intermediate (mid-pass) type conflicts by inserting
// unrealized_conversion_cast ops, and only converting those to custom
// materializations if they persist at the end of the pass. In our case,
// we'd only need to use custom materializations if we split this lowering
// across multiple passes.
}
};

struct ConvertAdd : public OpConversionPattern<AddOp> {
ConvertAdd(mlir::MLIRContext *context)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure of before but in MLIR 17+, I've had to instead implement this constructor,

ConvertAdd(TypeConverter& type_converter, MLIRContext* context)
    : OpConversionPattern<AddOp>(type_converter, context) {}

otherwise the call to patterns.add<ConvertAdd>(typeConverter, context); fails at compile time saying no matching constructor for initialization since the constructor we implemented here requires one argument but we provide two.

Also, it maybe just a minor preference thing but struct A : public B is equivalent to struct A : B since structs have public inheritance by default.

: OpConversionPattern<AddOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: implement
return success();
}
};

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
using PolyToStandardBase::PolyToStandardBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalDialect<PolyDialect>();

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}
};

} // namespace poly
} // namespace tutorial
} // namespace mlir
24 changes: 24 additions & 0 deletions lib/Conversion/PolyToStandard/PolyToStandard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_
#define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

// Extra includes needed for dependent dialects
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project

namespace mlir {
namespace tutorial {
namespace poly {

#define GEN_PASS_DECL
#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc"

} // namespace poly
} // namespace tutorial
} // namespace mlir

#endif // LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_
20 changes: 20 additions & 0 deletions lib/Conversion/PolyToStandard/PolyToStandard.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_
#define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_

include "mlir/Pass/PassBase.td"

def PolyToStandard : Pass<"poly-to-standard"> {
let summary = "Lower `poly` to standard MLIR dialects.";

let description = [{
This pass lowers the `poly` dialect to standard MLIR, a mixture of affine,
tensor, and arith.
}];
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::tutorial::poly::PolyDialect",
"mlir::tensor::TensorDialect",
];
}

#endif // LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cc_binary(
srcs = ["tutorial-opt.cpp"],
includes = ["include"],
deps = [
"//lib/Conversion/PolyToStandard",
"//lib/Dialect/Poly",
"//lib/Transform/Affine:Passes",
"//lib/Transform/Arith:Passes",
Expand Down
4 changes: 4 additions & 0 deletions tools/tutorial-opt.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "lib/Conversion/PolyToStandard/PolyToStandard.h"
#include "lib/Dialect/Poly/PolyDialect.h"
#include "lib/Transform/Affine/Passes.h"
#include "lib/Transform/Arith/Passes.h"
Expand All @@ -16,6 +17,9 @@ int main(int argc, char **argv) {
mlir::tutorial::registerAffinePasses();
mlir::tutorial::registerArithPasses();

// Dialect conversion passes
mlir::tutorial::poly::registerPolyToStandardPasses();

return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry));
}