Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canonicalizers and Declarative Rewrite Patterns #19

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion lib/Dialect/Poly/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ td_library(
srcs = [
"PolyDialect.td",
"PolyOps.td",
"PolyPatterns.td",
"PolyTypes.td",
],
includes = ["@heir//include"],
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
Expand Down Expand Up @@ -80,6 +80,23 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "canonicalize_inc_gen",
tbl_outs = [
(
["-gen-rewriters"],
"PolyCanonicalize.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolyPatterns.td",
deps = [
":td_files",
":types_inc_gen",
"@llvm-project//mlir:ComplexOpsTdFiles",
],
)

cc_library(
name = "Poly",
srcs = [
Expand All @@ -93,9 +110,11 @@ cc_library(
"PolyTypes.h",
],
deps = [
":canonicalize_inc_gen",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down
50 changes: 37 additions & 13 deletions lib/Dialect/Poly/PolyOps.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#include "lib/Dialect/Poly/PolyOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/PatternMatch.h"

// Required after PatternMatch.h
#include "lib/Dialect/Poly/PolyCanonicalize.cpp.inc"

namespace mlir {
namespace tutorial {
Expand All @@ -22,11 +26,10 @@ OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) {
}

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
auto lhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[0]);
auto rhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[1]);
auto lhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[0]);
auto rhs = dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getOperands()[1]);

if (!lhs || !rhs)
return nullptr;
if (!lhs || !rhs) return nullptr;

auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
auto maxIndex = lhs.size() + rhs.size() - 1;
Expand All @@ -43,7 +46,8 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
int j = 0;
for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
++rhsIt) {
// index is modulo degree because poly's semantics are defined modulo x^N = 1.
// index is modulo degree because poly's semantics are defined modulo x^N
// = 1.
result[(i + j) % degree] += *rhsIt * (*lhsIt);
++j;
}
Expand All @@ -58,15 +62,35 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
// Returns null if the cast failed, which corresponds to a failed fold.
return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
return dyn_cast_or_null<DenseIntElementsAttr>(adaptor.getInput());
}

LogicalResult EvalOp::verify() {
return getPoint().getType().isSignlessInteger(32)
? success()
: emitOpError("argument point must be a 32-bit integer");
auto pointTy = getPoint().getType();
bool isSignlessInteger = pointTy.isSignlessInteger(32);
auto complexPt = llvm::dyn_cast<ComplexType>(pointTy);
return isSignlessInteger || complexPt ? success()
: emitOpError(
"argument point must be a 32-bit "
"integer, or a complex number");
}

void AddOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {}

void SubOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {
results.add<DifferenceOfSquares>(context);
}

void MulOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {}

void EvalOp::getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
::mlir::MLIRContext *context) {
results.add<LiftConjThroughEval>(context);
}

} // namespace poly
} // namespace tutorial
} // namespace mlir
} // namespace poly
} // namespace tutorial
} // namespace mlir
8 changes: 6 additions & 2 deletions lib/Dialect/Poly/PolyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, Elementwis
let results = (outs PolyOrContainer:$output);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Poly_AddOp : Poly_BinOp<"add"> {
Expand All @@ -44,12 +45,15 @@ def Poly_FromTensorOp : Op<Poly_Dialect, "from_tensor", [Pure]> {
let hasFolder = 1;
}

def IntOrComplex : AnyTypeOf<[AnyInteger, AnyComplex]>;

def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>, Has32BitArguments]> {
let summary = "Evaluates a Polynomial at a given input value.";
let arguments = (ins Polynomial:$input, AnyInteger:$point);
let results = (outs AnyInteger:$output);
let arguments = (ins Polynomial:$input, IntOrComplex:$point);
let results = (outs IntOrComplex:$output);
let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
let hasVerifier = 1;
let hasCanonicalizer = 1;
}

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {
Expand Down
26 changes: 26 additions & 0 deletions lib/Dialect/Poly/PolyPatterns.td
Copy link

@UtkarshKunwar UtkarshKunwar May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following the tutorial and got stuck here e55bab9 with MLIR v18.1.x. I've described the issue on the LLVM repo here in detail. Just wanted to know if someone else has also faced this issue as I couldn't find the solution for it.

EDIT: This issue with the API was fixed later in the tutorial in the commits fa6d030 and d2e0a72 of #26. It has to do with the API changes with Complex operators like ConjOp now requiring two arguments $input and attribute $fastmath.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I haven't seen people use MLIR by installing the tools as separate binaries, and instead they tend to build everything from source at a specific LLVM commit (or HEAD).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must admit I'm not very familiar with the standard MLIR development process. Building a large project like LLVM can take some time on potato computers or in CI/CD setups. So I prefer stable binary releases for such dependencies as they enable faster "clean build" iterations. The dependencies are supposed to be frozen anyway! :)

I also faced problems with some MLIR support tools (LSP servers in particular) when building from source (I filed that issue here).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW We build LLVM from source in the GitHub CI, and I also think that "from source" LLVM build is necessary to build an out-of-tree dialect.
With ccache for the compilation caching, the whole LLVM+MLIR+minimal dialect build process takes about 18min on the free macos-14 M1 runner link.

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef LIB_DIALECT_POLY_POLYPATTERNS_TD_
#define LIB_DIALECT_POLY_POLYPATTERNS_TD_

include "PolyOps.td"
include "mlir/Dialect/Complex/IR/ComplexOps.td"
include "mlir/IR/PatternBase.td"

def LiftConjThroughEval : Pat<
(Poly_EvalOp $f, (ConjOp $z)),
(ConjOp (Poly_EvalOp $f, $z))
>;

def HasOneUse: Constraint<CPred<"$_self.hasOneUse()">, "has one use">;

// Rewrites (x^2 - y^2) as (x+y)(x-y) if x^2 and y^2 have no other uses.
def DifferenceOfSquares : Pattern<
(Poly_SubOp (Poly_MulOp:$lhs $x, $x), (Poly_MulOp:$rhs $y, $y)),
[
(Poly_AddOp:$sum $x, $y),
(Poly_SubOp:$diff $x, $y),
(Poly_MulOp:$res $sum, $diff),
],
[(HasOneUse:$lhs), (HasOneUse:$rhs)]
>;

#endif // LIB_DIALECT_POLY_POLYPATTERNS_TD_
45 changes: 45 additions & 0 deletions tests/poly_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,48 @@ func.func @test_simple() -> !poly.poly<10> {
%4 = poly.add %2, %3 : !poly.poly<10>
return %2 : !poly.poly<10>
}

// CHECK-LABEL: func.func @test_difference_of_squares
// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
func.func @test_difference_of_squares(
%0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
// CHECK: %[[sum:.+]] = poly.add %[[x]], %[[y]]
// CHECK: %[[diff:.+]] = poly.sub %[[x]], %[[y]]
// CHECK: %[[mul:.+]] = poly.mul %[[sum]], %[[diff]]
%2 = poly.mul %0, %0 : !poly.poly<3>
%3 = poly.mul %1, %1 : !poly.poly<3>
%4 = poly.sub %2, %3 : !poly.poly<3>
%5 = poly.add %4, %4 : !poly.poly<3>
return %5 : !poly.poly<3>
}

// CHECK-LABEL: func.func @test_difference_of_squares_other_uses
// CHECK-SAME: %[[x:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[y:.+]]: !poly.poly<3>
func.func @test_difference_of_squares_other_uses(
%0: !poly.poly<3>, %1: !poly.poly<3>) -> !poly.poly<3> {
// The canonicalization does not occur because x_squared has a second use.
// CHECK: %[[x_squared:.+]] = poly.mul %[[x]], %[[x]]
// CHECK: %[[y_squared:.+]] = poly.mul %[[y]], %[[y]]
// CHECK: %[[diff:.+]] = poly.sub %[[x_squared]], %[[y_squared]]
// CHECK: %[[sum:.+]] = poly.add %[[diff]], %[[x_squared]]
%2 = poly.mul %0, %0 : !poly.poly<3>
%3 = poly.mul %1, %1 : !poly.poly<3>
%4 = poly.sub %2, %3 : !poly.poly<3>
%5 = poly.add %4, %2 : !poly.poly<3>
return %5 : !poly.poly<3>
}

// CHECK-LABEL: func.func @test_normalize_conj_through_eval
// CHECK-SAME: %[[f:.+]]: !poly.poly<3>,
// CHECK-SAME: %[[z:.+]]: complex<f64>
func.func @test_normalize_conj_through_eval(
%f: !poly.poly<3>, %z: complex<f64>) -> complex<f64> {
// CHECK: %[[evaled:.+]] = poly.eval %[[f]], %[[z]]
// CHECK-NEXT: %[[eval_bar:.+]] = complex.conj %[[evaled]]
// CHECK-NEXT: return %[[eval_bar]]
%z_bar = complex.conj %z : complex<f64>
%evaled = poly.eval %f, %z_bar : (!poly.poly<3>, complex<f64>) -> complex<f64>
return %evaled : complex<f64>
}
4 changes: 4 additions & 0 deletions tests/poly_syntax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ module {
// CHECK: poly.eval
%6 = poly.eval %4, %5 : (!poly.poly<10>, i32) -> i32

%z = complex.constant [1.0, 2.0] : complex<f64>
// CHECK: poly.eval
%complex_eval = poly.eval %4, %z : (!poly.poly<10>, complex<f64>) -> complex<f64>

%7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>>
// CHECK: poly.add
%8 = poly.add %7, %7 : tensor<2x!poly.poly<10>>
Expand Down