Skip to content

Commit

Permalink
lower poly.constant
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 19, 2023
1 parent 282dcb1 commit 8eea0d1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
18 changes: 17 additions & 1 deletion lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
}
};

struct ConvertConstant : public OpConversionPattern<ConstantOp> {
ConvertConstant(mlir::MLIRContext *context)
: OpConversionPattern<ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
arith::ConstantOp constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), adaptor.getCoefficients());
rewriter.replaceOp(op.getOperation(), {constOp});
return success();
}
};

struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
using PolyToStandardBase::PolyToStandardBase;

Expand All @@ -63,7 +79,7 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd>(typeConverter, context);
patterns.add<ConvertAdd, ConvertConstant>(typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
Expand Down
12 changes: 12 additions & 0 deletions tests/poly_to_standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,15 @@ func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.pol
%2 = poly.add %0, %1: !poly.poly<10>
return %2 : !poly.poly<10>
}

// CHECK-LABEL: test_lower_add_and_fold
func.func @test_lower_add_and_fold() {
// CHECK: arith.constant dense<[2, 3, 4]> : tensor<3xi32>
%0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
// CHECK: arith.constant dense<[3, 4, 5]> : tensor<3xi32>
%1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10>
// would be an addi, but it was folded
// CHECK: arith.constant
%2 = poly.add %0, %1: !poly.poly<10>
return
}

0 comments on commit 8eea0d1

Please sign in to comment.