Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dialect Conversion #20

Merged
merged 15 commits into from
Oct 21, 2023
17 changes: 17 additions & 0 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ class PolyToStandardTypeConverter : public TypeConverter {
}
};

struct ConvertAdd : public OpConversionPattern<AddOp> {
ConvertAdd(mlir::MLIRContext *context)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure of before but in MLIR 17+, I've had to instead implement this constructor,

ConvertAdd(TypeConverter& type_converter, MLIRContext* context)
    : OpConversionPattern<AddOp>(type_converter, context) {}

otherwise the call to patterns.add<ConvertAdd>(typeConverter, context); fails at compile time saying no matching constructor for initialization since the constructor we implemented here requires one argument but we provide two.

Also, it maybe just a minor preference thing but struct A : public B is equivalent to struct A : B since structs have public inheritance by default.

: OpConversionPattern<AddOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// TODO: implement
return success();
}
};

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
using PolyToStandardBase::PolyToStandardBase;

Expand All @@ -40,9 +54,12 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
auto *module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalDialect<PolyDialect>();

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd>(typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
Expand Down