Skip to content

Commit

Permalink
Lower poly.eval
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Oct 20, 2023
1 parent d9e99bb commit b85f7ca
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
53 changes: 51 additions & 2 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,54 @@ struct ConvertMul : public OpConversionPattern<MulOp> {
}
};

struct ConvertEval : public OpConversionPattern<EvalOp> {
ConvertEval(mlir::MLIRContext *context)
: OpConversionPattern<EvalOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
EvalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto polyTensorType =
cast<RankedTensorType>(adaptor.getPolynomial().getType());
auto numTerms = polyTensorType.getShape()[0];
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto lowerBound =
b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1));
auto numTermsOp = b.create<arith::ConstantOp>(b.getIndexType(),
b.getIndexAttr(numTerms + 1));
auto step = lowerBound;

auto poly = adaptor.getPolynomial();
auto point = adaptor.getPoint();

// Horner's method:
//
// accum = 0
// for i = 1, 2, ..., N
// accum = point * accum + coeff[N - i]
auto accum =
b.create<arith::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(0));
auto loop = b.create<scf::ForOp>(
lowerBound, numTermsOp, step, accum.getResult(),
[&](OpBuilder &builder, Location loc, Value loopIndex,
ValueRange loopState) {
ImplicitLocOpBuilder b(op.getLoc(), builder);
auto accum = loopState.front();
auto coeffIndex = b.create<arith::SubIOp>(numTermsOp, loopIndex);
auto mulOp = b.create<arith::MulIOp>(point, accum);
auto result = b.create<arith::AddIOp>(
mulOp, b.create<tensor::ExtractOp>(poly, coeffIndex.getResult()));
b.create<scf::YieldOp>(result.getResult());
});

rewriter.replaceOp(op, loop.getResult(0));
return success();
}
};

struct ConvertFromTensor : public OpConversionPattern<FromTensorOp> {
ConvertFromTensor(mlir::MLIRContext *context)
: OpConversionPattern<FromTensorOp>(context) {}
Expand Down Expand Up @@ -211,8 +259,9 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {

RewritePatternSet patterns(context);
PolyToStandardTypeConverter typeConverter(context);
patterns.add<ConvertAdd, ConvertSub, ConvertMul, ConvertFromTensor,
ConvertToTensor, ConvertConstant>(typeConverter, context);
patterns.add<ConvertAdd, ConvertSub, ConvertEval, ConvertMul,
ConvertFromTensor, ConvertToTensor, ConvertConstant>(
typeConverter, context);

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Poly/PolyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>;

def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>, Has32BitArguments]> {
let summary = "Evaluates a Polynomial at a given input value.";
let arguments = (ins Polynomial:$input, IntOrComplex:$point);
let arguments = (ins Polynomial:$polynomial, IntOrComplex:$point);
let results = (outs IntOrComplex:$output);
let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
let assemblyFormat = "$polynomial `,` $point attr-dict `:` `(` qualified(type($polynomial)) `,` type($point) `)` `->` type($output)";
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
Expand Down
20 changes: 20 additions & 0 deletions tests/poly_to_standard.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,23 @@ func.func @test_lower_mul(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.pol
return %2 : !poly.poly<10>
}


// CHECK-LABEL: test_lower_eval
// CHECK-SAME: (%[[poly:.*]]: [[T:tensor<10xi32>]], %[[point:.*]]: i32) -> i32 {
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[c11:.*]] = arith.constant 11 : index
// CHECK: %[[accum:.*]] = arith.constant 0 : i32
// CHECK: %[[loop:.*]] = scf.for %[[iv:.*]] = %[[c1]] to %[[c11]] step %[[c1]] iter_args(%[[iter_arg:.*]] = %[[accum]]) -> (i32) {
// CHECK: %[[coeffIndex:.*]] = arith.subi %c11, %[[iv]]
// CHECK: %[[mulOp:.*]] = arith.muli %[[point]], %[[iter_arg]]
// CHECK: %[[nextCoeff:.*]] = tensor.extract %[[poly]][%[[coeffIndex]]]
// CHECK: %[[next:.*]] = arith.addi %[[mulOp]], %[[nextCoeff]]
// CHECK: scf.yield %[[next]]
// CHECK: }
// CHECK: return %[[loop]]
// CHECK: }
func.func @test_lower_eval(%0 : !poly.poly<10>, %1 : i32) -> i32 {
%2 = poly.eval %0, %1: (!poly.poly<10>, i32) -> i32
return %2 : i32
}

0 comments on commit b85f7ca

Please sign in to comment.