Skip to content

Commit

Permalink
demonstrate what a materialization would involve
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 20, 2023
1 parent 3570914 commit 1f5dfdd
Showing 1 changed file with 16 additions and 27 deletions.
43 changes: 16 additions & 27 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ class PolyToStandardTypeConverter : public TypeConverter {
return RankedTensorType::get({degreeBound}, elementTy);
});

// We don't include any custom materialization hooks because this lowering
// is all done in a single pass. The dialect conversion framework works by
// resolving intermediate (mid-pass) type conflicts by inserting
// unrealized_conversion_cast ops, and only converting those to custom
// materializations if they persist at the end of the pass. In our case,
// we'd only need to use custom materializations if we split this lowering
// across multiple passes.
// Convert from a tensor type to a poly type: use from_tensor
addSourceMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
return builder.create<poly::FromTensorOp>(loc, type, inputs[0]);
});

// Convert from a poly type to a tensor type: use to_tensor
addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs, Location loc) -> Value {
return builder.create<poly::ToTensorOp>(loc, type, inputs[0]);
});
}
};

Expand All @@ -53,22 +57,6 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
}
};

struct ConvertSub : public OpConversionPattern<SubOp> {
ConvertSub(mlir::MLIRContext *context)
: OpConversionPattern<SubOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
SubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
arith::SubIOp subOp = rewriter.create<arith::SubIOp>(
op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op.getOperation(), {subOp});
return success();
}
};

struct ConvertMul : public OpConversionPattern<MulOp> {
ConvertMul(mlir::MLIRContext *context)
: OpConversionPattern<MulOp>(context) {}
Expand Down Expand Up @@ -257,13 +245,14 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {

ConversionTarget target(*context);
target.addLegalDialect<arith::ArithDialect>();
target.addIllegalDialect<PolyDialect>();
target.addIllegalOp<AddOp, MulOp, EvalOp, ConstantOp, EvalOp, FromTensorOp,
ToTensorOp>();
// target.addIllegalDialect<PolyDialect>();

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd, ConvertConstant, ConvertSub, ConvertEval,
ConvertMul, ConvertFromTensor, ConvertToTensor>(typeConverter,
context);
patterns.add<ConvertAdd, ConvertConstant, ConvertEval, ConvertMul,
ConvertFromTensor, ConvertToTensor>(typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
Expand Down

0 comments on commit 1f5dfdd

Please sign in to comment.