From 617671907f4e5231ac0a45332503117893fb331c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 20 Oct 2023 14:17:55 -0700 Subject: [PATCH] improve constant lowering --- lib/Conversion/PolyToStandard/PolyToStandard.cpp | 8 +++++--- tests/poly_to_standard.mlir | 11 +++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp index 440addb..40a2b54 100644 --- a/lib/Conversion/PolyToStandard/PolyToStandard.cpp +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -239,9 +239,11 @@ struct ConvertConstant : public OpConversionPattern { LogicalResult matchAndRewrite( ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - arith::ConstantOp constOp = rewriter.create( - op.getLoc(), adaptor.getCoefficients()); - rewriter.replaceOp(op.getOperation(), {constOp}); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto constOp = b.create(adaptor.getCoefficients()); + auto fromTensorOp = + b.create(op.getResult().getType(), constOp); + rewriter.replaceOp(op, fromTensorOp.getResult()); return success(); } }; diff --git a/tests/poly_to_standard.mlir b/tests/poly_to_standard.mlir index 759bcb7..bab4e6b 100644 --- a/tests/poly_to_standard.mlir +++ b/tests/poly_to_standard.mlir @@ -98,3 +98,14 @@ func.func @test_lower_eval(%0 : !poly.poly<10>, %1 : i32) -> i32 { return %2 : i32 } + +// CHECK-LABEL: test_lower_many +// CHECK-NOT: poly +func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 { + %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> + %1 = poly.add %0, %arg : !poly.poly<10> + %2 = poly.mul %1, %1 : !poly.poly<10> + %3 = poly.sub %2, %arg : !poly.poly<10> + %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32 + return %4 : i32 +}