diff --git a/lib/Dialect/Poly/BUILD b/lib/Dialect/Poly/BUILD index 0651282..c1c36f7 100644 --- a/lib/Dialect/Poly/BUILD +++ b/lib/Dialect/Poly/BUILD @@ -8,12 +8,14 @@ td_library( name = "td_files", srcs = [ "PolyDialect.td", + "PolyOps.td", "PolyTypes.td", ], includes = ["@heir//include"], deps = [ # the base mlir target for defining operations and dialects in tablegen "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", ], ) @@ -80,7 +82,10 @@ gentbl_cc_library( cc_library( name = "Poly", - srcs = ["PolyDialect.cpp"], + srcs = [ + "PolyDialect.cpp", + "PolyOps.cpp", + ], hdrs = [ "PolyDialect.h", "PolyOps.h", @@ -90,6 +95,7 @@ cc_library( ":dialect_inc_gen", ":ops_inc_gen", ":types_inc_gen", + "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], diff --git a/lib/Dialect/Poly/PolyDialect.cpp b/lib/Dialect/Poly/PolyDialect.cpp index 1529c97..291cb12 100644 --- a/lib/Dialect/Poly/PolyDialect.cpp +++ b/lib/Dialect/Poly/PolyDialect.cpp @@ -2,8 +2,8 @@ #include "lib/Dialect/Poly/PolyOps.h" #include "lib/Dialect/Poly/PolyTypes.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" #include "mlir/include/mlir/IR/Builders.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" #include "lib/Dialect/Poly/PolyDialect.cpp.inc" #define GET_TYPEDEF_CLASSES @@ -26,6 +26,14 @@ void PolyDialect::initialize() { >(); } +Operation *PolyDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + auto coeffs = dyn_cast(value); + if (!coeffs) + return nullptr; + return builder.create(loc, type, coeffs); +} + } // namespace poly } // namespace tutorial } // namespace mlir diff --git a/lib/Dialect/Poly/PolyDialect.td b/lib/Dialect/Poly/PolyDialect.td index f00e35f..c0214dd 100644 --- a/lib/Dialect/Poly/PolyDialect.td +++ b/lib/Dialect/Poly/PolyDialect.td @@ -14,6 +14,7 @@ def Poly_Dialect : Dialect { let cppNamespace = "::mlir::tutorial::poly"; let useDefaultTypePrinterParser = 1; + let hasConstantMaterializer = 1; } #endif // LIB_DIALECT_POLY_POLYDIALECT_TD_ diff --git a/lib/Dialect/Poly/PolyOps.cpp b/lib/Dialect/Poly/PolyOps.cpp new file mode 100644 index 0000000..b54b117 --- /dev/null +++ b/lib/Dialect/Poly/PolyOps.cpp @@ -0,0 +1,66 @@ +#include "lib/Dialect/Poly/PolyOps.h" + +#include "mlir/Dialect/CommonFolders.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace tutorial { +namespace poly { + +OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) { + return adaptor.getCoefficients(); +} + +OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) { + return constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; }); +} + +OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) { + return constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, APInt b) { return a - b; }); +} + +OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) { + auto lhs = dyn_cast(adaptor.getOperands()[0]); + auto rhs = dyn_cast(adaptor.getOperands()[1]); + + if (!lhs || !rhs) + return nullptr; + + auto degree = getResult().getType().cast().getDegreeBound(); + auto maxIndex = lhs.size() + rhs.size() - 1; + + SmallVector result; + result.reserve(maxIndex); + for (int i = 0; i < maxIndex; ++i) { + result.push_back(APInt((*lhs.begin()).getBitWidth(), 0)); + } + + int i = 0; + for (auto lhsIt = lhs.value_begin(); lhsIt != lhs.value_end(); + ++lhsIt) { + int j = 0; + for (auto rhsIt = rhs.value_begin(); rhsIt != rhs.value_end(); + ++rhsIt) { + // index is modulo degree because poly's semantics are defined modulo x^N = 1. + result[(i + j) % degree] += *rhsIt * (*lhsIt); + ++j; + } + ++i; + } + + return DenseIntElementsAttr::get( + RankedTensorType::get(static_cast(result.size()), + IntegerType::get(getContext(), 32)), + result); +} + +OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) { + // Returns null if the cast failed, which corresponds to a failed fold. + return dyn_cast(adaptor.getInput()); +} + +} // namespace poly +} // namespace tutorial +} // namespace mlir diff --git a/lib/Dialect/Poly/PolyOps.td b/lib/Dialect/Poly/PolyOps.td index 5755f03..f5e1dcd 100644 --- a/lib/Dialect/Poly/PolyOps.td +++ b/lib/Dialect/Poly/PolyOps.td @@ -3,6 +3,8 @@ include "PolyDialect.td" include "PolyTypes.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Type constraint for poly binop arguments: polys, vectors of polys, or @@ -13,6 +15,7 @@ class Poly_BinOp : Op { @@ -32,6 +35,7 @@ def Poly_FromTensorOp : Op { let arguments = (ins TensorOf<[AnyInteger]>:$input); let results = (outs Polynomial:$output); let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; + let hasFolder = 1; } def Poly_EvalOp : Op { @@ -41,5 +45,13 @@ def Poly_EvalOp : Op { let assemblyFormat = "$input `,` $point attr-dict `:` `(` type($input) `,` type($point) `)` `->` type($output)"; } +def Poly_ConstantOp : Op { + let summary = "Define a constant polynomial via an attribute."; + let arguments = (ins AnyIntElementsAttr:$coefficients); + let results = (outs Polynomial:$output); + let assemblyFormat = "$coefficients attr-dict `:` type($output)"; + let hasFolder = 1; +} + #endif // LIB_DIALECT_POLY_POLYOPS_TD_ diff --git a/tests/poly_canonicalize.mlir b/tests/poly_canonicalize.mlir new file mode 100644 index 0000000..8d8636a --- /dev/null +++ b/tests/poly_canonicalize.mlir @@ -0,0 +1,13 @@ +// RUN: tutorial-opt --canonicalize %s | FileCheck %s + +// CHECK-LABEL: @test_simple +func.func @test_simple() -> !poly.poly<10> { + // CHECK: poly.constant dense<[2, 4, 6]> + // CHECK-NEXT: return + %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> + %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> + %2 = poly.add %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + return %2 : !poly.poly<10> +} diff --git a/tests/poly_syntax.mlir b/tests/poly_syntax.mlir index 29515b2..0eb66b3 100644 --- a/tests/poly_syntax.mlir +++ b/tests/poly_syntax.mlir @@ -31,6 +31,12 @@ module { // CHECK: poly.add %9 = poly.add %7, %4 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>> + // CHECK: poly.constant + %10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> + %11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10> + %12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10> + %13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10> + return %4 : !poly.poly<10> } } diff --git a/tests/sccp.mlir b/tests/sccp.mlir new file mode 100644 index 0000000..39b3a24 --- /dev/null +++ b/tests/sccp.mlir @@ -0,0 +1,35 @@ +// RUN: tutorial-opt -pass-pipeline="builtin.module(func.func(sccp))" %s | FileCheck %s + +// Note how sscp creates new constants for the computed values, +// though it does not remove the dead code. + +// CHECK-LABEL: @test_arith_sccp +// CHECK-NEXT: %[[v0:.*]] = arith.constant 63 : i32 +// CHECK-NEXT: %[[v1:.*]] = arith.constant 49 : i32 +// CHECK-NEXT: %[[v2:.*]] = arith.constant 14 : i32 +// CHECK-NEXT: %[[v3:.*]] = arith.constant 8 : i32 +// CHECK-NEXT: %[[v4:.*]] = arith.constant 7 : i32 +// CHECK-NEXT: return %[[v2]] : i32 +func.func @test_arith_sccp() -> i32 { + %0 = arith.constant 7 : i32 + %1 = arith.constant 8 : i32 + %2 = arith.addi %0, %0 : i32 + %3 = arith.muli %0, %0 : i32 + %4 = arith.addi %2, %3 : i32 + return %2 : i32 +} + +// CHECK-LABEL: @test_poly_sccp +func.func @test_poly_sccp() -> !poly.poly<10> { + %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> + %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> + // CHECK: poly.constant dense<[2, 8, 20, 24, 18]> + // CHECK: poly.constant dense<[1, 4, 10, 12, 9]> + // CHECK: poly.constant dense<[1, 2, 3]> + // CHECK-NOT: poly.mul + // CHECK-NOT: poly.add + %2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + return %2 : !poly.poly<10> +}