Skip to content

Commit

Permalink
Canonicalizers and Declarative Rewrite Patterns (#19)
Browse files Browse the repository at this point in the history
* slight formatting changes

* Add empty canonicalizers to poly binops

* dyn_cast -> dyn_cast_or_null

* Canonicalize a difference of squares

* upgrade eval to accept complex inputs

* Add tablegen pattern to lift conj through eval

* add generated patterns to eval canonicalizer

* rewrite DifferenceOfSquares in tablegen
  • Loading branch information
j2kun authored Sep 19, 2023
1 parent 822c84d commit f25c12a
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 16 deletions.
21 changes: 20 additions & 1 deletion lib/Dialect/Poly/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ td_library(
srcs = [
"PolyDialect.td",
"PolyOps.td",
"PolyPatterns.td",
"PolyTypes.td",
],
includes = ["@heir//include"],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
Expand Down Expand Up @@ -80,6 +80,23 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "canonicalize_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"PolyCanonicalize.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolyPatterns.td",
deps = [
":td_files",
":types_inc_gen",
"@llvm-project//mlir:ComplexOpsTdFiles",
],
)

cc_library(
name = "Poly",
srcs = [
Expand All @@ -93,9 +110,11 @@ cc_library(
"PolyTypes.h",
],
deps = [
":canonicalize_inc_gen",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down
50 changes: 37 additions & 13 deletions lib/Dialect/Poly/PolyOps.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#include "lib/Dialect/Poly/PolyOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/PatternMatch.h"

// Required after PatternMatch.h
#include "lib/Dialect/Poly/PolyCanonicalize.cpp.inc"

namespace mlir {
namespace tutorial {
Expand All @@ -22,11 +26,10 @@ OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) {
}

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
auto lhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[0]);
auto rhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[1]);
auto lhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[0]);
auto rhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[1]);

if (!lhs || !rhs)
return nullptr;
if (!lhs || !rhs) return nullptr;

auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
auto maxIndex = lhs.size() + rhs.size() - 1;
Expand All @@ -43,7 +46,8 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
int j = 0;
for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
++rhsIt) {
// index is modulo degree because poly's semantics are defined modulo x^N = 1.
// index is modulo degree because poly's semantics are defined modulo x^N
// = 1.
result[(i + j) % degree] += *rhsIt * (*lhsIt);
++j;
}
Expand All @@ -58,15 +62,35 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
// Returns null if the cast failed, which corresponds to a failed fold.
return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
return dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getInput());
}

LogicalResult EvalOp::verify() {
return getPoint().getType().isSignlessInteger(32)
? success()
: emitOpError("argument point must be a 32-bit integer");
auto pointTy = getPoint().getType();
bool isSignlessInteger = pointTy.isSignlessInteger(32);
auto complexPt = llvm::dyn_cast<ComplexType>(pointTy);
return isSignlessInteger || complexPt ? success()
: emitOpError(
"argument point must be a 32-bit "
"integer, or a complex number");
}

void AddOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {}

void SubOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {
results.add<DifferenceOfSquares>(context);
}

void MulOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {}

void EvalOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {
results.add<LiftConjThroughEval>(context);
}

} // namespace poly
} // namespace tutorial
} // namespace mlir
} // namespace poly
} // namespace tutorial
} // namespace mlir
8 changes: 6 additions & 2 deletions lib/Dialect/Poly/PolyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, Elementwis
let results = (outs PolyOrContainer:$output);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Poly_AddOp : Poly_BinOp<"add"> {
Expand All @@ -44,12 +45,15 @@ def Poly_FromTensorOp : Op<Poly_Dialect, "from_tensor", [Pure]> {
let hasFolder = 1;
}

def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>;

def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>, Has32BitArguments]> {
let summary = "Evaluates a Polynomial at a given input value.";
let arguments = (ins Polynomial:$input, AnyInteger:$point);
let results = (outs AnyInteger:$output);
let arguments = (ins Polynomial:$input, IntOrComplex:$point);
let results = (outs IntOrComplex:$output);
let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {
Expand Down
26 changes: 26 additions & 0 deletions lib/Dialect/Poly/PolyPatterns.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef LIB_DIALECT_POLY_POLYPATTERNS_TD_
#define LIB_DIALECT_POLY_POLYPATTERNS_TD_

include "PolyOps.td"
include "mlir/Dialect/Complex/IR/ComplexOps.td"
include "mlir/IR/PatternBase.td"

def LiftConjThroughEval : Pat<
(Poly_EvalOp $f, (ConjOp $z)),
(ConjOp (Poly_EvalOp $f, $z))
>;

def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
(Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
[
(Poly_AddOp:$sum $x, $y),
(Poly_SubOp:$diff $x, $y),
(Poly_MulOp:$res $sum, $diff),
],
[(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;

#endif // LIB_DIALECT_POLY_POLYPATTERNS_TD_
45 changes: 45 additions & 0 deletions tests/poly_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,48 @@ func.func @test_simple() -> !poly.poly<10> {
%4 = poly.add %2, %3 : !poly.poly<10>
return %2 : !poly.poly<10>
}

// CHECK-LABEL: func.func @test_difference_of_squares
// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
func.func @test_difference_of_squares(
%0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
// CHECK: %[[sum:.+]] = poly.add %[[x]], %[[y]]
// CHECK: %[[diff:.+]] = poly.sub %[[x]], %[[y]]
// CHECK: %[[mul:.+]] = poly.mul %[[sum]], %[[diff]]
%2 = poly.mul %0, %0 : !poly.poly<3>
%3 = poly.mul %1, %1 : !poly.poly<3>
%4 = poly.sub %2, %3 : !poly.poly<3>
%5 = poly.add %4, %4 : !poly.poly<3>
return %5 : !poly.poly<3>
}

// CHECK-LABEL: func.func @test_difference_of_squares_other_uses
// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
func.func @test_difference_of_squares_other_uses(
%0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
// The canonicalization does not occur because x_squared has a second use.
// CHECK: %[[x_squared:.+]] = poly.mul %[[x]], %[[x]]
// CHECK: %[[y_squared:.+]] = poly.mul %[[y]], %[[y]]
// CHECK: %[[diff:.+]] = poly.sub %[[x_squared]], %[[y_squared]]
// CHECK: %[[sum:.+]] = poly.add %[[diff]], %[[x_squared]]
%2 = poly.mul %0, %0 : !poly.poly<3>
%3 = poly.mul %1, %1 : !poly.poly<3>
%4 = poly.sub %2, %3 : !poly.poly<3>
%5 = poly.add %4, %2 : !poly.poly<3>
return %5 : !poly.poly<3>
}

// CHECK-LABEL: func.func @test_normalize_conj_through_eval
// CHECK-SAME: %[[f:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[z:.+]]: complex<f64>
func.func @test_normalize_conj_through_eval(
%f: !poly.poly<3>, %z: complex<f64>) -> complex<f64> {
// CHECK: %[[evaled:.+]] = poly.eval %[[f]], %[[z]]
// CHECK-NEXT: %[[eval_bar:.+]] = complex.conj %[[evaled]]
// CHECK-NEXT: return %[[eval_bar]]
%z_bar = complex.conj %z : complex<f64>
%evaled = poly.eval %f, %z_bar : (!poly.poly<3>, complex<f64>) -> complex<f64>
return %evaled : complex<f64>
}
4 changes: 4 additions & 0 deletions tests/poly_syntax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ module {
// CHECK: poly.eval
%6 = poly.eval %4, %5 : (!poly.poly<10>, i32) -> i32

%z = complex.constant [1.0, 2.0] : complex<f64>
// CHECK: poly.eval
%complex_eval = poly.eval %4, %z : (!poly.poly<10>, complex<f64>) -> complex<f64>

%7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>>
// CHECK: poly.add
%8 = poly.add %7, %7 : tensor<2x!poly.poly<10>>
Expand Down

0 comments on commit f25c12a

Please sign in to comment.