diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index eb0170e..7884bdb 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -27,17 +27,13 @@ class PolyToStandardTypeConverter : public TypeConverter { return RankedTensorType::get({degreeBound}, elementTy); }); - // Convert from a tensor type to a poly type: use from_tensor - addSourceMaterialization([](OpBuilder &builder, Type type, - ValueRange inputs, Location loc) -> Value { - return builder.create(loc, type, inputs[0]); - }); - - // Convert from a poly type to a tensor type: use to_tensor - addTargetMaterialization([](OpBuilder &builder, Type type, - ValueRange inputs, Location loc) -> Value { - return builder.create(loc, type, inputs[0]); - }); + // We don't include any custom materialization hooks because this lowering + // is all done in a single pass. The dialect conversion framework works by + // resolving intermediate (mid-pass) type conflicts by inserting + // unrealized_conversion_cast ops, and only converting those to custom + // materializations if they persist at the end of the pass. In our case, + // we'd only need to use custom materializations if we split this lowering + // across multiple passes. } }; @@ -57,6 +53,22 @@ struct ConvertAdd : public OpConversionPattern { } }; +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::SubIOp subOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {subOp}); + return success(); + } +}; + struct ConvertMul : public OpConversionPattern { ConvertMul(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -245,14 +257,13 @@ struct PolyToStandard : impl::PolyToStandardBase { ConversionTarget target(*context); target.addLegalDialect(); - target.addIllegalOp(); - // target.addIllegalDialect(); + target.addIllegalDialect(); RewritePatternSet patterns(context); PolyToStandardTypeConverter typeConverter(context); - patterns.add(typeConverter, context); + patterns.add(typeConverter, + context); populateFunctionOpInterfaceTypeConversionPattern( patterns, typeConverter);