From 6c40e28c379d6fd2648f3b94291b65113b72608b Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 20 Oct 2023 15:14:51 -0700 Subject: [PATCH] Revert "demonstrate what a materialization would involve" This reverts commit 1f5dfdd5c18949d3f717d7705aa97bdf3c3b67e1. --- .../PolyToStandard/PolyToStandard.cpp | 43 ++++++++++++------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index eb0170e..7884bdb 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -27,17 +27,13 @@ class PolyToStandardTypeConverter : public TypeConverter { return RankedTensorType::get({degreeBound}, elementTy); }); - // Convert from a tensor type to a poly type: use from_tensor - addSourceMaterialization([](OpBuilder &builder, Type type, - ValueRange inputs, Location loc) -> Value { - return builder.create(loc, type, inputs[0]); - }); - - // Convert from a poly type to a tensor type: use to_tensor - addTargetMaterialization([](OpBuilder &builder, Type type, - ValueRange inputs, Location loc) -> Value { - return builder.create(loc, type, inputs[0]); - }); + // We don't include any custom materialization hooks because this lowering + // is all done in a single pass. The dialect conversion framework works by + // resolving intermediate (mid-pass) type conflicts by inserting + // unrealized_conversion_cast ops, and only converting those to custom + // materializations if they persist at the end of the pass. In our case, + // we'd only need to use custom materializations if we split this lowering + // across multiple passes. } }; @@ -57,6 +53,22 @@ struct ConvertAdd : public OpConversionPattern { } }; +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::SubIOp subOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {subOp}); + return success(); + } +}; + struct ConvertMul : public OpConversionPattern { ConvertMul(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -245,14 +257,13 @@ struct PolyToStandard : impl::PolyToStandardBase { ConversionTarget target(*context); target.addLegalDialect(); - target.addIllegalOp(); - // target.addIllegalDialect(); + target.addIllegalDialect(); RewritePatternSet patterns(context); PolyToStandardTypeConverter typeConverter(context); - patterns.add(typeConverter, context); + patterns.add(typeConverter, + context); populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter);