Poly dialect: add TableGen-generated canonicalization patterns.

Summary of what the hunks below do:
  * BUILD: list PolyPatterns.td in td_library srcs, drop the `includes`
    attribute, add a `canonicalize_inc_gen` target that runs mlir-tblgen
    with -gen-rewriters to emit PolyCanonicalize.cpp.inc, and make the
    `Poly` cc_library depend on it and on the Complex dialect.
  * PolyOps.cpp: include the generated rewriters, switch the folders to
    dyn_cast_or_null, let EvalOp::verify accept a complex point, and define
    getCanonicalizationPatterns for AddOp, SubOp, MulOp and EvalOp.
  * PolyOps.td: introduce IntOrComplex and use it for poly.eval's point and
    output; enable the canonicalizer on poly.eval.
  * PolyPatterns.td (new): LiftConjThroughEval and DifferenceOfSquares.
  * tests: cover both rewrites and the complex form of poly.eval.

NOTE(review): this text was recovered from a whitespace-collapsed copy.
Line breaks, indentation and blank context lines were re-derived from the
hunk line counts and should be confirmed against the upstream commit.
NOTE(review): angle-bracketed spans appear to have been stripped by the
extraction (e.g. `dyn_cast_or_null(`, `.cast()`, `results.add(context)`,
`Constraint, "has one use">`, `Op, Has32BitArguments]>`, `complex`), and the
first PolyOps.td hunk has no body. They are left exactly as found; restore
them from the upstream commit before attempting to apply this patch.

diff --git a/lib/Dialect/Poly/BUILD b/lib/Dialect/Poly/BUILD
index 9dfb6d4..5a4d699 100644
--- a/lib/Dialect/Poly/BUILD
+++ b/lib/Dialect/Poly/BUILD
@@ -9,9 +9,9 @@ td_library(
     srcs = [
         "PolyDialect.td",
         "PolyOps.td",
+        "PolyPatterns.td",
         "PolyTypes.td",
     ],
-    includes = ["@heir//include"],
     deps = [
         "@llvm-project//mlir:BuiltinDialectTdFiles",
         "@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
@@ -80,6 +80,23 @@ gentbl_cc_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "canonicalize_inc_gen",
+    tbl_outs = [
+        (
+            ["-gen-rewriters"],
+            "PolyCanonicalize.cpp.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "PolyPatterns.td",
+    deps = [
+        ":td_files",
+        ":types_inc_gen",
+        "@llvm-project//mlir:ComplexOpsTdFiles",
+    ],
+)
+
 cc_library(
     name = "Poly",
     srcs = [
@@ -93,9 +110,11 @@ cc_library(
         "PolyTypes.h",
     ],
     deps = [
+        ":canonicalize_inc_gen",
         ":dialect_inc_gen",
         ":ops_inc_gen",
         ":types_inc_gen",
+        "@llvm-project//mlir:ComplexDialect",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:InferTypeOpInterface",
diff --git a/lib/Dialect/Poly/PolyOps.cpp b/lib/Dialect/Poly/PolyOps.cpp
index b92c58a..473f2d4 100644
--- a/lib/Dialect/Poly/PolyOps.cpp
+++ b/lib/Dialect/Poly/PolyOps.cpp
@@ -1,7 +1,11 @@
 #include "lib/Dialect/Poly/PolyOps.h"
 
 #include "mlir/Dialect/CommonFolders.h"
-#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/IR/PatternMatch.h"
+
+// Required after PatternMatch.h
+#include "lib/Dialect/Poly/PolyCanonicalize.cpp.inc"
 
 namespace mlir {
 namespace tutorial {
@@ -22,11 +26,10 @@ OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) {
 }
 
 OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
-  auto lhs = dyn_cast(adaptor.getOperands()[0]);
-  auto rhs = dyn_cast(adaptor.getOperands()[1]);
+  auto lhs = dyn_cast_or_null(adaptor.getOperands()[0]);
+  auto rhs = dyn_cast_or_null(adaptor.getOperands()[1]);
 
-  if (!lhs || !rhs)
-    return nullptr;
+  if (!lhs || !rhs) return nullptr;
 
   auto degree = getResult().getType().cast().getDegreeBound();
   auto maxIndex = lhs.size() + rhs.size() - 1;
@@ -43,7 +46,8 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
 
     int j = 0;
     for (auto rhsIt = rhs.value_begin(); rhsIt != rhs.value_end(); ++rhsIt) {
-      // index is modulo degree because poly's semantics are defined modulo x^N = 1.
+      // index is modulo degree because poly's semantics are defined modulo x^N
+      // = 1.
       result[(i + j) % degree] += *rhsIt * (*lhsIt);
       ++j;
     }
@@ -58,15 +62,35 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
 
 OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
   // Returns null if the cast failed, which corresponds to a failed fold.
-  return dyn_cast(adaptor.getInput());
+  return dyn_cast_or_null(adaptor.getInput());
 }
 
 LogicalResult EvalOp::verify() {
-  return getPoint().getType().isSignlessInteger(32)
-             ? success()
-             : emitOpError("argument point must be a 32-bit integer");
+  auto pointTy = getPoint().getType();
+  bool isSignlessInteger = pointTy.isSignlessInteger(32);
+  auto complexPt = llvm::dyn_cast(pointTy);
+  return isSignlessInteger || complexPt ? success()
+                                        : emitOpError(
+                                              "argument point must be a 32-bit "
+                                              "integer, or a complex number");
+}
+
+void AddOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
+                                        ::mlir::MLIRContext *context) {}
+
+void SubOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
+                                        ::mlir::MLIRContext *context) {
+  results.add(context);
+}
+
+void MulOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
+                                        ::mlir::MLIRContext *context) {}
+
+void EvalOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
+                                         ::mlir::MLIRContext *context) {
+  results.add(context);
 }
 
-} // namespace poly
-} // namespace tutorial
-} // namespace mlir
+}  // namespace poly
+}  // namespace tutorial
+}  // namespace mlir
diff --git a/lib/Dialect/Poly/PolyOps.td b/lib/Dialect/Poly/PolyOps.td
index 520ab09..2106bb9 100644
--- a/lib/Dialect/Poly/PolyOps.td
+++ b/lib/Dialect/Poly/PolyOps.td
@@ -22,6 +22,7 @@ class Poly_BinOp : Op {
@@ -44,12 +45,15 @@ def Poly_FromTensorOp : Op {
   let hasFolder = 1;
 }
 
+def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>;
+
 def Poly_EvalOp : Op, Has32BitArguments]> {
   let summary = "Evaluates a Polynomial at a given input value.";
-  let arguments = (ins Polynomial:$input, AnyInteger:$point);
-  let results = (outs AnyInteger:$output);
+  let arguments = (ins Polynomial:$input, IntOrComplex:$point);
+  let results = (outs IntOrComplex:$output);
   let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Poly_ConstantOp : Op {
diff --git a/lib/Dialect/Poly/PolyPatterns.td b/lib/Dialect/Poly/PolyPatterns.td
new file mode 100644
index 0000000..442a1ba
--- /dev/null
+++ b/lib/Dialect/Poly/PolyPatterns.td
@@ -0,0 +1,26 @@
+#ifndef LIB_DIALECT_POLY_POLYPATTERNS_TD_
+#define LIB_DIALECT_POLY_POLYPATTERNS_TD_
+
+include "PolyOps.td"
+include "mlir/Dialect/Complex/IR/ComplexOps.td"
+include "mlir/IR/PatternBase.td"
+
+def LiftConjThroughEval : Pat<
+  (Poly_EvalOp $f, (ConjOp $z)),
+  (ConjOp (Poly_EvalOp $f, $z))
+>;
+
+def HasOneUse: Constraint, "has one use">;
+
+// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
+def DifferenceOfSquares : Pattern<
+  (Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
+  [
+    (Poly_AddOp:$sum $x, $y),
+    (Poly_SubOp:$diff $x, $y),
+    (Poly_MulOp:$res $sum, $diff),
+  ],
+  [(HasOneUse:$lhs), (HasOneUse:$rhs)]
+>;
+
+#endif // LIB_DIALECT_POLY_POLYPATTERNS_TD_
diff --git a/tests/poly_canonicalize.mlir b/tests/poly_canonicalize.mlir
index 0b219ab..a1c4cdf 100644
--- a/tests/poly_canonicalize.mlir
+++ b/tests/poly_canonicalize.mlir
@@ -11,3 +11,48 @@ func.func @test_simple() -> !poly.poly<10> {
   %4 = poly.add %2, %3 : !poly.poly<10>
   return %2 : !poly.poly<10>
 }
+
+// CHECK-LABEL: func.func @test_difference_of_squares
+// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
+// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
+func.func @test_difference_of_squares(
+    %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
+  // CHECK: %[[sum:.+]] = poly.add %[[x]], %[[y]]
+  // CHECK: %[[diff:.+]] = poly.sub %[[x]], %[[y]]
+  // CHECK: %[[mul:.+]] = poly.mul %[[sum]], %[[diff]]
+  %2 = poly.mul %0, %0 : !poly.poly<3>
+  %3 = poly.mul %1, %1 : !poly.poly<3>
+  %4 = poly.sub %2, %3 : !poly.poly<3>
+  %5 = poly.add %4, %4 : !poly.poly<3>
+  return %5 : !poly.poly<3>
+}
+
+// CHECK-LABEL: func.func @test_difference_of_squares_other_uses
+// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
+// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
+func.func @test_difference_of_squares_other_uses(
+    %0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
+  // The canonicalization does not occur because x_squared has a second use.
+  // CHECK: %[[x_squared:.+]] = poly.mul %[[x]], %[[x]]
+  // CHECK: %[[y_squared:.+]] = poly.mul %[[y]], %[[y]]
+  // CHECK: %[[diff:.+]] = poly.sub %[[x_squared]], %[[y_squared]]
+  // CHECK: %[[sum:.+]] = poly.add %[[diff]], %[[x_squared]]
+  %2 = poly.mul %0, %0 : !poly.poly<3>
+  %3 = poly.mul %1, %1 : !poly.poly<3>
+  %4 = poly.sub %2, %3 : !poly.poly<3>
+  %5 = poly.add %4, %2 : !poly.poly<3>
+  return %5 : !poly.poly<3>
+}
+
+// CHECK-LABEL: func.func @test_normalize_conj_through_eval
+// CHECK-SAME: %[[f:.+]]: !poly.poly<3>,
+// CHECK-SAME: %[[z:.+]]: complex
+func.func @test_normalize_conj_through_eval(
+    %f: !poly.poly<3>, %z: complex) -> complex {
+  // CHECK: %[[evaled:.+]] = poly.eval %[[f]], %[[z]]
+  // CHECK-NEXT: %[[eval_bar:.+]] = complex.conj %[[evaled]]
+  // CHECK-NEXT: return %[[eval_bar]]
+  %z_bar = complex.conj %z : complex
+  %evaled = poly.eval %f, %z_bar : (!poly.poly<3>, complex) -> complex
+  return %evaled : complex
+}
diff --git a/tests/poly_syntax.mlir b/tests/poly_syntax.mlir
index 8ab6db8..f3c7f5e 100644
--- a/tests/poly_syntax.mlir
+++ b/tests/poly_syntax.mlir
@@ -25,6 +25,10 @@ module {
     // CHECK: poly.eval
     %6 = poly.eval %4, %5 : (!poly.poly<10>, i32) -> i32
 
+    %z = complex.constant [1.0, 2.0] : complex
+    // CHECK: poly.eval
+    %complex_eval = poly.eval %4, %z : (!poly.poly<10>, complex) -> complex
+
     %7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>>
     // CHECK: poly.add
     %8 = poly.add %7, %7 : tensor<2x!poly.poly<10>>