From 1f5dfdd5c18949d3f717d7705aa97bdf3c3b67e1 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 20 Oct 2023 15:13:27 -0700 Subject: [PATCH] demonstrate what a materialization would involve --- .../PolyToStandard/PolyToStandard.cpp | 43 +++++++------------ 1 file changed, 16 insertions(+), 27 deletions(-) diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index 7884bdb..eb0170e 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -27,13 +27,17 @@ class PolyToStandardTypeConverter : public TypeConverter { return RankedTensorType::get({degreeBound}, elementTy); }); - // We don't include any custom materialization hooks because this lowering - // is all done in a single pass. The dialect conversion framework works by - // resolving intermediate (mid-pass) type conflicts by inserting - // unrealized_conversion_cast ops, and only converting those to custom - // materializations if they persist at the end of the pass. In our case, - // we'd only need to use custom materializations if we split this lowering - // across multiple passes. + // Convert from a tensor type to a poly type: use from_tensor + addSourceMaterialization([](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, type, inputs[0]); + }); + + // Convert from a poly type to a tensor type: use to_tensor + addTargetMaterialization([](OpBuilder &builder, Type type, + ValueRange inputs, Location loc) -> Value { + return builder.create(loc, type, inputs[0]); + }); } }; @@ -53,22 +57,6 @@ struct ConvertAdd : public OpConversionPattern { } }; -struct ConvertSub : public OpConversionPattern { - ConvertSub(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - SubOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - arith::SubIOp subOp = rewriter.create( - op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); - rewriter.replaceOp(op.getOperation(), {subOp}); - return success(); - } -}; - struct ConvertMul : public OpConversionPattern { ConvertMul(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -257,13 +245,14 @@ struct PolyToStandard : impl::PolyToStandardBase { ConversionTarget target(*context); target.addLegalDialect(); - target.addIllegalDialect(); + target.addIllegalOp(); + // target.addIllegalDialect(); RewritePatternSet patterns(context); PolyToStandardTypeConverter typeConverter(context); - patterns.add(typeConverter, - context); + patterns.add(typeConverter, context); populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter);