Skip to content

Commit

Permalink
improve constant lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 20, 2023
1 parent 9067521 commit 6176719
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 5 additions & 3 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,11 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
LogicalResult matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
arith::ConstantOp constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), adaptor.getCoefficients());
rewriter.replaceOp(op.getOperation(), {constOp});
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto constOp = b.create<arith::ConstantOp>(adaptor.getCoefficients());
auto fromTensorOp =
b.create<FromTensorOp>(op.getResult().getType(), constOp);
rewriter.replaceOp(op, fromTensorOp.getResult());
return success();
}
};
Expand Down
11 changes: 11 additions & 0 deletions tests/poly_to_standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,14 @@ func.func @test_lower_eval(%0 : !poly.poly<10>, %1 : i32) -> i32 {
return %2 : i32
}


// CHECK-LABEL: test_lower_many
// CHECK-NOT: poly
func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 {
%0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%1 = poly.add %0, %arg : !poly.poly<10>
%2 = poly.mul %1, %1 : !poly.poly<10>
%3 = poly.sub %2, %arg : !poly.poly<10>
%4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32
return %4 : i32
}

0 comments on commit 6176719

Please sign in to comment.