diff --git a/.gitignore b/.gitignore index 5756b0b..36802d6 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,4 @@ externals /CMakeSettings.json # Compilation databases compile_commands.json -tablegen_compile_commands.yml \ No newline at end of file +tablegen_compile_commands.yml diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index f66a974..714031d 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Dialect) -add_subdirectory(Transform) \ No newline at end of file +add_subdirectory(Conversion) +add_subdirectory(Transform) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt new file mode 100644 index 0000000..fb91c06 --- /dev/null +++ b/lib/Conversion/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(PolyToStandard) diff --git a/lib/Conversion/PolyToStandard/BUILD b/lib/Conversion/PolyToStandard/BUILD new file mode 100644 index 0000000..c51d951 --- /dev/null +++ b/lib/Conversion/PolyToStandard/BUILD @@ -0,0 +1,42 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_visibility = ["//visibility:public"], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=PolyToStandard", + ], + "PolyToStandard.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "PolyToStandard.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +cc_library( + name = "PolyToStandard", + srcs = ["PolyToStandard.cpp"], + hdrs = ["PolyToStandard.h"], + deps = [ + "pass_inc_gen", + "//lib/Dialect/Poly", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Conversion/PolyToStandard/CMakeLists.txt 
b/lib/Conversion/PolyToStandard/CMakeLists.txt new file mode 100644 index 0000000..96dde92 --- /dev/null +++ b/lib/Conversion/PolyToStandard/CMakeLists.txt @@ -0,0 +1,29 @@ +add_mlir_library(PolyToStandard + PolyToStandard.cpp + + ${PROJECT_SOURCE_DIR}/lib/Conversion/PolyToStandard/ + ADDITIONAL_HEADER_DIRS + + DEPENDS + PolyToStandardPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRPoly + MLIRArithDialect + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRPass + MLIRSCFDialect + MLIRTensorDialect + MLIRTransforms + ) + +set(LLVM_TARGET_DEFINITIONS PolyToStandard.td) +mlir_tablegen(PolyToStandard.h.inc -gen-pass-decls -name PolyToStandard) +add_dependencies(mlir-headers MLIRPolyOpsIncGen) +add_public_tablegen_target(PolyToStandardPassIncGen) +add_mlir_doc(PolyToStandard PolyToStandard PolyToStandard/ -gen-pass-doc) diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.cpp b/lib/Conversion/PolyToStandard/PolyToStandard.cpp new file mode 100644 index 0000000..7884bdb --- /dev/null +++ b/lib/Conversion/PolyToStandard/PolyToStandard.cpp @@ -0,0 +1,299 @@ +#include "lib/Conversion/PolyToStandard/PolyToStandard.h" + +#include "lib/Dialect/Poly/PolyOps.h" +#include "lib/Dialect/Poly/PolyTypes.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace tutorial { +namespace poly { + +#define GEN_PASS_DEF_POLYTOSTANDARD +#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" + +class PolyToStandardTypeConverter : public TypeConverter { + public: + PolyToStandardTypeConverter(MLIRContext *ctx) { + 
addConversion([](Type type) { return type; }); + addConversion([ctx](PolynomialType type) -> Type { + int degreeBound = type.getDegreeBound(); + IntegerType elementTy = + IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless); + return RankedTensorType::get({degreeBound}, elementTy); + }); + + // We don't include any custom materialization hooks because this lowering + // is all done in a single pass. The dialect conversion framework works by + // resolving intermediate (mid-pass) type conflicts by inserting + // unrealized_conversion_cast ops, and only converting those to custom + // materializations if they persist at the end of the pass. In our case, + // we'd only need to use custom materializations if we split this lowering + // across multiple passes. + } +}; + +struct ConvertAdd : public OpConversionPattern { + ConvertAdd(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::AddIOp addOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {addOp}); + return success(); + } +}; + +struct ConvertSub : public OpConversionPattern { + ConvertSub(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SubOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + arith::SubIOp subOp = rewriter.create( + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); + rewriter.replaceOp(op.getOperation(), {subOp}); + return success(); + } +}; + +struct ConvertMul : public OpConversionPattern { + ConvertMul(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + MulOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + auto polymulTensorType = cast(adaptor.getLhs().getType()); + auto numTerms = polymulTensorType.getShape()[0]; + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // Create an all-zeros tensor to store the result + auto polymulResult = b.create( + polymulTensorType, DenseElementsAttr::get(polymulTensorType, 0)); + + // Loop bounds and step. + auto lowerBound = + b.create(b.getIndexType(), b.getIndexAttr(0)); + auto numTermsOp = + b.create(b.getIndexType(), b.getIndexAttr(numTerms)); + auto step = + b.create(b.getIndexType(), b.getIndexAttr(1)); + + auto p0 = adaptor.getLhs(); + auto p1 = adaptor.getRhs(); + + // for i = 0, ..., N-1 + // for j = 0, ..., N-1 + // product[i+j (mod N)] += p0[i] * p1[j] + auto outerLoop = b.create( + lowerBound, numTermsOp, step, ValueRange(polymulResult.getResult()), + [&](OpBuilder &builder, Location loc, Value p0Index, + ValueRange loopState) { + ImplicitLocOpBuilder b(op.getLoc(), builder); + auto innerLoop = b.create( + lowerBound, numTermsOp, step, loopState, + [&](OpBuilder &builder, Location loc, Value p1Index, + ValueRange loopState) { + ImplicitLocOpBuilder b(op.getLoc(), builder); + auto accumTensor = loopState.front(); + auto destIndex = b.create( + b.create(p0Index, p1Index), numTermsOp); + auto mulOp = b.create( + b.create(p0, ValueRange(p0Index)), + b.create(p1, ValueRange(p1Index))); + auto result = b.create( + mulOp, b.create(accumTensor, + destIndex.getResult())); + auto stored = b.create(result, accumTensor, + destIndex.getResult()); + b.create(stored.getResult()); + }); + + b.create(innerLoop.getResults()); + }); + + rewriter.replaceOp(op, outerLoop.getResult(0)); + return success(); + } +}; + +struct ConvertEval : public OpConversionPattern { + ConvertEval(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + EvalOp op, OpAdaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + auto polyTensorType = + cast(adaptor.getPolynomial().getType()); + auto numTerms = polyTensorType.getShape()[0]; + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto lowerBound = + b.create(b.getIndexType(), b.getIndexAttr(1)); + auto numTermsOp = b.create(b.getIndexType(), + b.getIndexAttr(numTerms + 1)); + auto step = lowerBound; + + auto poly = adaptor.getPolynomial(); + auto point = adaptor.getPoint(); + + // Horner's method, with N = numTerms. NOTE(review): the index built below + // is numTermsOp - i = (N + 1) - i, which reads coeff[N]..coeff[1]; confirm. + // accum = 0 + // for i = 1, 2, ..., N + // accum = point * accum + coeff[N - i] + auto accum = + b.create(b.getI32Type(), b.getI32IntegerAttr(0)); + auto loop = b.create( + lowerBound, numTermsOp, step, accum.getResult(), + [&](OpBuilder &builder, Location loc, Value loopIndex, + ValueRange loopState) { + ImplicitLocOpBuilder b(op.getLoc(), builder); + auto accum = loopState.front(); + auto coeffIndex = b.create(numTermsOp, loopIndex); + auto mulOp = b.create(point, accum); + auto result = b.create( + mulOp, b.create(poly, coeffIndex.getResult())); + b.create(result.getResult()); + }); + + rewriter.replaceOp(op, loop.getResult(0)); + return success(); + } +}; + +struct ConvertFromTensor : public OpConversionPattern { + ConvertFromTensor(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + FromTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resultTensorTy = cast( + typeConverter->convertType(op->getResultTypes()[0])); + auto resultShape = resultTensorTy.getShape()[0]; + auto resultEltTy = resultTensorTy.getElementType(); + + auto inputTensorTy = op.getInput().getType(); + auto inputShape = inputTensorTy.getShape()[0]; + + // Zero pad the tensor if the coefficients' size is less than the polynomial's + // degree bound. 
+ ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto coeffValue = adaptor.getInput(); + if (inputShape < resultShape) { + SmallVector low, high; + low.push_back(rewriter.getIndexAttr(0)); + high.push_back(rewriter.getIndexAttr(resultShape - inputShape)); + coeffValue = b.create( + resultTensorTy, coeffValue, low, high, + b.create(rewriter.getIntegerAttr(resultEltTy, 0)), + /*nofold=*/false); + } + + rewriter.replaceOp(op, coeffValue); + return success(); + } +}; + +struct ConvertToTensor : public OpConversionPattern { + ConvertToTensor(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ToTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + +struct ConvertConstant : public OpConversionPattern { + ConvertConstant(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto constOp = b.create(adaptor.getCoefficients()); + auto fromTensorOp = + b.create(op.getResult().getType(), constOp); + rewriter.replaceOp(op, fromTensorOp.getResult()); + return success(); + } +}; + +struct PolyToStandard : impl::PolyToStandardBase { + using PolyToStandardBase::PolyToStandardBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(context); + PolyToStandardTypeConverter typeConverter(context); + patterns.add(typeConverter, + context); + + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + 
target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::ReturnOp op) { return typeConverter.isLegal(op); }); + + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace poly +} // namespace tutorial +} // namespace mlir diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.h b/lib/Conversion/PolyToStandard/PolyToStandard.h new file mode 100644 index 0000000..ca7dfc9 --- /dev/null +++ b/lib/Conversion/PolyToStandard/PolyToStandard.h @@ -0,0 +1,24 @@ +#ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ +#define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +// Extra includes needed for dependent dialects +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project + +namespace mlir { +namespace tutorial { +namespace poly { + +#define GEN_PASS_DECL +#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" + +} // namespace poly +} // namespace tutorial +} // namespace mlir + +#endif // 
LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_H_ diff --git a/lib/Conversion/PolyToStandard/PolyToStandard.td b/lib/Conversion/PolyToStandard/PolyToStandard.td new file mode 100644 index 0000000..31e9d9b --- /dev/null +++ b/lib/Conversion/PolyToStandard/PolyToStandard.td @@ -0,0 +1,21 @@ +#ifndef LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ +#define LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ + +include "mlir/Pass/PassBase.td" + +def PolyToStandard : Pass<"poly-to-standard"> { + let summary = "Lower `poly` to standard MLIR dialects."; + + let description = [{ + This pass lowers the `poly` dialect to standard MLIR, a mixture of scf, + tensor, and arith. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::tutorial::poly::PolyDialect", + "mlir::tensor::TensorDialect", + "mlir::scf::SCFDialect", + ]; +} + +#endif // LIB_CONVERSION_POLYTOSTANDARD_POLYTOSTANDARD_TD_ diff --git a/lib/Dialect/Poly/CMakeLists.txt b/lib/Dialect/Poly/CMakeLists.txt index 628253c..f001992 100644 --- a/lib/Dialect/Poly/CMakeLists.txt +++ b/lib/Dialect/Poly/CMakeLists.txt @@ -24,4 +24,4 @@ add_mlir_dialect_library(MLIRPoly ${PROJECT_SOURCE_DIR}/lib/Dialect/Poly LINK_LIBS PUBLIC - ) \ No newline at end of file + ) diff --git a/lib/Dialect/Poly/PolyOps.td b/lib/Dialect/Poly/PolyOps.td index 2106bb9..76cb2f9 100644 --- a/lib/Dialect/Poly/PolyOps.td +++ b/lib/Dialect/Poly/PolyOps.td @@ -45,13 +45,20 @@ def Poly_FromTensorOp : Op { let hasFolder = 1; } +def Poly_ToTensorOp : Op { + let summary = "Converts a polynomial to a tensor of its integer coefficients."; + let arguments = (ins Polynomial:$input); + let results = (outs TensorOf<[AnyInteger]>:$output); + let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; +} + def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>; def Poly_EvalOp : Op, Has32BitArguments]> { let summary = "Evaluates a Polynomial at a given input value."; - let arguments = (ins Polynomial:$input, 
IntOrComplex:$point); + let arguments = (ins Polynomial:$polynomial, IntOrComplex:$point); let results = (outs IntOrComplex:$output); - let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)"; + let assemblyFormat = "$polynomial `,` $point attr-dict `:` `(` qualified(type($polynomial)) `,` type($point) `)` `->` type($output)"; let hasVerifier = 1; let hasCanonicalizer = 1; } diff --git a/lib/Transform/Affine/CMakeLists.txt b/lib/Transform/Affine/CMakeLists.txt index dd3340b..eaa2b68 100644 --- a/lib/Transform/Affine/CMakeLists.txt +++ b/lib/Transform/Affine/CMakeLists.txt @@ -10,4 +10,4 @@ add_mlir_library(AffineFullUnroll set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Affine) add_public_tablegen_target(MLIRAffineFullUnrollPasses) -add_mlir_doc(Passes AffinePasses ./ -gen-pass-doc) \ No newline at end of file +add_mlir_doc(Passes AffinePasses ./ -gen-pass-doc) diff --git a/tests/poly_syntax.mlir b/tests/poly_syntax.mlir index f3c7f5e..4b2e39b 100644 --- a/tests/poly_syntax.mlir +++ b/tests/poly_syntax.mlir @@ -39,6 +39,9 @@ module { %12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10> %13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10> + // CHECK: poly.to_tensor + %14 = poly.to_tensor %1 : !poly.poly<10> -> tensor<10xi32> + return %4 : !poly.poly<10> } } diff --git a/tests/poly_to_standard.mlir b/tests/poly_to_standard.mlir new file mode 100644 index 0000000..ffeacb3 --- /dev/null +++ b/tests/poly_to_standard.mlir @@ -0,0 +1,111 @@ +// RUN: tutorial-opt --poly-to-standard %s | FileCheck %s + +// CHECK-LABEL: test_lower_add +func.func @test_lower_add(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { + // CHECK: arith.addi + %2 = poly.add %0, %1: !poly.poly<10> + return %2 : !poly.poly<10> +} + +// CHECK-LABEL: test_lower_sub +func.func @test_lower_sub(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { + // 
CHECK: arith.subi + %2 = poly.sub %0, %1: !poly.poly<10> + return %2 : !poly.poly<10> +} + +// CHECK-LABEL: test_lower_to_tensor( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { +// CHECK-NEXT: return %[[V0]] : [[T]] +func.func @test_lower_to_tensor(%0: !poly.poly<10>) -> tensor<10xi32> { + %2 = poly.to_tensor %0: !poly.poly<10> -> tensor<10xi32> + return %2 : tensor<10xi32> +} + +// CHECK-LABEL: test_lower_from_tensor( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T]] { +// CHECK-NEXT: return %[[V0]] : [[T]] +func.func @test_lower_from_tensor(%0 : tensor<10xi32>) -> !poly.poly<10> { + %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<10> + return %2 : !poly.poly<10> +} + +// CHECK-LABEL: test_lower_from_tensor_extend( +// CHECK-SAME: %[[V0:.*]]: [[T:tensor<10xi32>]]) -> [[T2:tensor<20xi32>]] { +// CHECK: %[[V1:.*]] = tensor.pad %[[V0]] low[0] high[10] +// CHECK: return %[[V1]] : [[T2]] +func.func @test_lower_from_tensor_extend(%0 : tensor<10xi32>) -> !poly.poly<20> { + %2 = poly.from_tensor %0: tensor<10xi32> -> !poly.poly<20> + return %2 : !poly.poly<20> +} + +// CHECK-LABEL: test_lower_add_and_fold +func.func @test_lower_add_and_fold() { + // CHECK: arith.constant dense<[2, 3, 4]> : tensor<3xi32> + %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> + // CHECK: arith.constant dense<[3, 4, 5]> : tensor<3xi32> + %1 = poly.constant dense<[3, 4, 5]> : tensor<3xi32> : !poly.poly<10> + // would be an addi, but it was folded + // CHECK: arith.constant + %2 = poly.add %0, %1: !poly.poly<10> + return +} + +// CHECK-LABEL: test_lower_mul +// CHECK-SAME: (%[[p0:.*]]: [[T:tensor<10xi32>]], %[[p1:.*]]: [[T]]) -> [[T]] { +// CHECK: %[[cst:.*]] = arith.constant dense<0> : [[T]] +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[c10:.*]] = arith.constant 10 : index +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[outer:.*]] = scf.for %[[outer_iv:.*]] = %[[c0]] to %[[c10]] step %[[c1]] 
iter_args(%[[outer_iter_arg:.*]] = %[[cst]]) -> ([[T]]) { +// CHECK: %[[inner:.*]] = scf.for %[[inner_iv:.*]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[inner_iter_arg:.*]] = %[[outer_iter_arg]]) -> ([[T]]) { +// CHECK: %[[index_sum:.*]] = arith.addi %arg2, %arg4 +// CHECK: %[[dest_index:.*]] = arith.remui %[[index_sum]], %[[c10]] +// CHECK-DAG: %[[p0_extracted:.*]] = tensor.extract %[[p0]][%[[outer_iv]]] +// CHECK-DAG: %[[p1_extracted:.*]] = tensor.extract %[[p1]][%[[inner_iv]]] +// CHECK: %[[coeff_mul:.*]] = arith.muli %[[p0_extracted]], %[[p1_extracted]] +// CHECK: %[[accum:.*]] = tensor.extract %[[inner_iter_arg]][%[[dest_index]]] +// CHECK: %[[to_insert:.*]] = arith.addi %[[coeff_mul]], %[[accum]] +// CHECK: %[[inserted:.*]] = tensor.insert %[[to_insert]] into %[[inner_iter_arg]][%[[dest_index]]] +// CHECK: scf.yield %[[inserted]] +// CHECK: } +// CHECK: scf.yield %[[inner]] +// CHECK: } +// CHECK: return %[[outer]] +// CHECK: } +func.func @test_lower_mul(%0 : !poly.poly<10>, %1 : !poly.poly<10>) -> !poly.poly<10> { + %2 = poly.mul %0, %1: !poly.poly<10> + return %2 : !poly.poly<10> +} + + +// CHECK-LABEL: test_lower_eval +// CHECK-SAME: (%[[poly:.*]]: [[T:tensor<10xi32>]], %[[point:.*]]: i32) -> i32 { +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[c11:.*]] = arith.constant 11 : index +// CHECK: %[[accum:.*]] = arith.constant 0 : i32 +// CHECK: %[[loop:.*]] = scf.for %[[iv:.*]] = %[[c1]] to %[[c11]] step %[[c1]] iter_args(%[[iter_arg:.*]] = %[[accum]]) -> (i32) { +// CHECK: %[[coeffIndex:.*]] = arith.subi %c11, %[[iv]] +// CHECK: %[[mulOp:.*]] = arith.muli %[[point]], %[[iter_arg]] +// CHECK: %[[nextCoeff:.*]] = tensor.extract %[[poly]][%[[coeffIndex]]] +// CHECK: %[[next:.*]] = arith.addi %[[mulOp]], %[[nextCoeff]] +// CHECK: scf.yield %[[next]] +// CHECK: } +// CHECK: return %[[loop]] +// CHECK: } +func.func @test_lower_eval(%0 : !poly.poly<10>, %1 : i32) -> i32 { + %2 = poly.eval %0, %1: (!poly.poly<10>, i32) -> i32 + return %2 : i32 +} 
+ + +// CHECK-LABEL: test_lower_many +// CHECK-NOT: poly +func.func @test_lower_many(%arg : !poly.poly<10>, %point : i32) -> i32 { + %0 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> + %1 = poly.add %0, %arg : !poly.poly<10> + %2 = poly.mul %1, %1 : !poly.poly<10> + %3 = poly.sub %2, %arg : !poly.poly<10> + %4 = poly.eval %3, %point: (!poly.poly<10>, i32) -> i32 + return %4 : i32 +} diff --git a/tools/BUILD b/tools/BUILD index 987e90c..a9df186 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -11,6 +11,7 @@ cc_binary( srcs = ["tutorial-opt.cpp"], includes = ["include"], deps = [ + "//lib/Conversion/PolyToStandard", "//lib/Dialect/Poly", "//lib/Transform/Affine:Passes", "//lib/Transform/Arith:Passes", diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 930bb80..59f33b1 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -9,6 +9,7 @@ set (LIBS MulToAdd MLIROptLib MLIRPass + PolyToStandard ) add_llvm_executable(tutorial-opt tutorial-opt.cpp) @@ -16,4 +17,4 @@ add_llvm_executable(tutorial-opt tutorial-opt.cpp) llvm_update_compile_flags(tutorial-opt) target_link_libraries(tutorial-opt PRIVATE ${LIBS}) -mlir_check_all_link_libraries(tutorial-opt) \ No newline at end of file +mlir_check_all_link_libraries(tutorial-opt) diff --git a/tools/tutorial-opt.cpp b/tools/tutorial-opt.cpp index a4f3ba4..9336d03 100644 --- a/tools/tutorial-opt.cpp +++ b/tools/tutorial-opt.cpp @@ -1,3 +1,4 @@ +#include "lib/Conversion/PolyToStandard/PolyToStandard.h" #include "lib/Dialect/Poly/PolyDialect.h" #include "lib/Transform/Affine/Passes.h" #include "lib/Transform/Arith/Passes.h" @@ -16,6 +17,9 @@ int main(int argc, char **argv) { mlir::tutorial::registerAffinePasses(); mlir::tutorial::registerArithPasses(); + // Dialect conversion passes + mlir::tutorial::poly::registerPolyToStandardPasses(); + return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "Tutorial Pass Driver", registry)); }