Skip to content

Commit

Permalink
fix the off-by-one error in the lowering of eval
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 31, 2023
1 parent bc47670 commit 8422826
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 3 additions & 1 deletion lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ struct ConvertEval : public OpConversionPattern<EvalOp> {
auto lowerBound =
b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1));
auto numTermsOp = b.create<arith::ConstantOp>(b.getIndexType(),
b.getIndexAttr(numTerms));
auto upperBound = b.create<arith::ConstantOp>(b.getIndexType(),
b.getIndexAttr(numTerms + 1));
auto step = lowerBound;

Expand All @@ -163,7 +165,7 @@ struct ConvertEval : public OpConversionPattern<EvalOp> {
auto accum =
b.create<arith::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(0));
auto loop = b.create<scf::ForOp>(
lowerBound, numTermsOp, step, accum.getResult(),
lowerBound, upperBound, step, accum.getResult(),
[&](OpBuilder &builder, Location loc, Value loopIndex,
ValueRange loopState) {
ImplicitLocOpBuilder b(op.getLoc(), builder);
Expand Down
3 changes: 2 additions & 1 deletion tests/poly_to_standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ func.func @test_lower_mul(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.pol
// CHECK-LABEL: test_lower_eval
// CHECK-SAME: (%[[poly:.*]]: [[T:tensor<10xi32>]], %[[point:.*]]: i32) -> i32 {
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[c10:.*]] = arith.constant 10 : index
// CHECK: %[[c11:.*]] = arith.constant 11 : index
// CHECK: %[[accum:.*]] = arith.constant 0 : i32
// CHECK: %[[loop:.*]] = scf.for %[[iv:.*]] = %[[c1]] to %[[c11]] step %[[c1]] iter_args(%[[iter_arg:.*]] = %[[accum]]) -> (i32) {
// CHECK: %[[coeffIndex:.*]] = arith.subi %c11, %[[iv]]
// CHECK: %[[coeffIndex:.*]] = arith.subi %[[c10]], %[[iv]]
// CHECK: %[[mulOp:.*]] = arith.muli %[[point]], %[[iter_arg]]
// CHECK: %[[nextCoeff:.*]] = tensor.extract %[[poly]][%[[coeffIndex]]]
// CHECK: %[[next:.*]] = arith.addi %[[mulOp]], %[[nextCoeff]]
Expand Down

0 comments on commit 8422826

Please sign in to comment.