Skip to content

Commit

Permalink
Folders (#15)
Browse files Browse the repository at this point in the history
* add boilerplate for poly.constant

* add folder for poly.constant

* add failing tests demonstrating folding

* Add binary op folders

* add a folder for FromTensorOp

* add constant materializer to poly dialect
  • Loading branch information
j2kun authored Sep 8, 2023
1 parent 101aac0 commit 7ea3b9c
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 2 deletions.
8 changes: 7 additions & 1 deletion lib/Dialect/Poly/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ td_library(
name = "td_files",
srcs = [
"PolyDialect.td",
"PolyOps.td",
"PolyTypes.td",
],
includes = ["@heir//include"],
deps = [
# the base mlir target for defining operations and dialects in tablegen
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)
Expand Down Expand Up @@ -80,7 +82,10 @@ gentbl_cc_library(

cc_library(
name = "Poly",
srcs = ["PolyDialect.cpp"],
srcs = [
"PolyDialect.cpp",
"PolyOps.cpp",
],
hdrs = [
"PolyDialect.h",
"PolyOps.h",
Expand All @@ -90,6 +95,7 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
Expand Down
10 changes: 9 additions & 1 deletion lib/Dialect/Poly/PolyDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#include "lib/Dialect/Poly/PolyOps.h"
#include "lib/Dialect/Poly/PolyTypes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h"

#include "lib/Dialect/Poly/PolyDialect.cpp.inc"
#define GET_TYPEDEF_CLASSES
Expand All @@ -26,6 +26,14 @@ void PolyDialect::initialize() {
>();
}

Operation *PolyDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
auto coeffs = dyn_cast<DenseIntElementsAttr>(value);
if (!coeffs)
return nullptr;
return builder.create<ConstantOp>(loc, type, coeffs);
}

} // namespace poly
} // namespace tutorial
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/Poly/PolyDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def Poly_Dialect : Dialect {
let cppNamespace = "::mlir::tutorial::poly";

let useDefaultTypePrinterParser = 1;
let hasConstantMaterializer = 1;
}

#endif // LIB_DIALECT_POLY_POLYDIALECT_TD_
66 changes: 66 additions & 0 deletions lib/Dialect/Poly/PolyOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "lib/Dialect/Poly/PolyOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace tutorial {
namespace poly {

OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) {
return adaptor.getCoefficients();
}

OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) {
return constFoldBinaryOp<IntegerAttr, APInt>(
adaptor.getOperands(), [&](APInt a, APInt b) { return a + b; });
}

OpFoldResult SubOp::fold(SubOp::FoldAdaptor adaptor) {
return constFoldBinaryOp<IntegerAttr, APInt>(
adaptor.getOperands(), [&](APInt a, APInt b) { return a - b; });
}

OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
auto lhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[0]);
auto rhs = dyn_cast<DenseIntElementsAttr>(adaptor.getOperands()[1]);

if (!lhs || !rhs)
return nullptr;

auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
auto maxIndex = lhs.size() + rhs.size() - 1;

SmallVector<APInt, 8> result;
result.reserve(maxIndex);
for (int i = 0; i < maxIndex; ++i) {
result.push_back(APInt((*lhs.begin()).getBitWidth(), 0));
}

int i = 0;
for (auto lhsIt = lhs.value_begin<APInt>(); lhsIt != lhs.value_end<APInt>();
++lhsIt) {
int j = 0;
for (auto rhsIt = rhs.value_begin<APInt>(); rhsIt != rhs.value_end<APInt>();
++rhsIt) {
// index is modulo degree because poly's semantics are defined modulo x^N = 1.
result[(i + j) % degree] += *rhsIt * (*lhsIt);
++j;
}
++i;
}

return DenseIntElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(result.size()),
IntegerType::get(getContext(), 32)),
result);
}

OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
// Returns null if the cast failed, which corresponds to a failed fold.
return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
}

} // namespace poly
} // namespace tutorial
} // namespace mlir
12 changes: 12 additions & 0 deletions lib/Dialect/Poly/PolyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

include "PolyDialect.td"
include "PolyTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Type constraint for poly binop arguments: polys, vectors of polys, or
Expand All @@ -13,6 +15,7 @@ class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, Elementwis
let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs);
let results = (outs PolyOrContainer:$output);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
let hasFolder = 1;
}

def Poly_AddOp : Poly_BinOp<"add"> {
Expand All @@ -32,6 +35,7 @@ def Poly_FromTensorOp : Op<Poly_Dialect, "from_tensor", [Pure]> {
let arguments = (ins TensorOf<[AnyInteger]>:$input);
let results = (outs Polynomial:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
let hasFolder = 1;
}

def Poly_EvalOp : Op<Poly_Dialect, "eval"> {
Expand All @@ -41,5 +45,13 @@ def Poly_EvalOp : Op<Poly_Dialect, "eval"> {
let assemblyFormat = "$input `,` $point attr-dict `:` `(` type($input) `,` type($point) `)` `->` type($output)";
}

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {
let summary = "Define a constant polynomial via an attribute.";
let arguments = (ins AnyIntElementsAttr:$coefficients);
let results = (outs Polynomial:$output);
let assemblyFormat = "$coefficients attr-dict `:` type($output)";
let hasFolder = 1;
}


#endif // LIB_DIALECT_POLY_POLYOPS_TD_
13 changes: 13 additions & 0 deletions tests/poly_canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: tutorial-opt --canonicalize %s | FileCheck %s

// CHECK-LABEL: @test_simple
func.func @test_simple() -> !poly.poly<10> {
// CHECK: poly.constant dense<[2, 4, 6]>
// CHECK-NEXT: return
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
%2 = poly.add %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
return %2 : !poly.poly<10>
}
6 changes: 6 additions & 0 deletions tests/poly_syntax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ module {
// CHECK: poly.add
%9 = poly.add %7, %4 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>>

// CHECK: poly.constant
%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
%11 = poly.constant dense<[2, 3, 4]> : tensor<3xi8> : !poly.poly<10>
%12 = poly.constant dense<"0x020304"> : tensor<3xi8> : !poly.poly<10>
%13 = poly.constant dense<4> : tensor<100xi32> : !poly.poly<10>

return %4 : !poly.poly<10>
}
}
35 changes: 35 additions & 0 deletions tests/sccp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: tutorial-opt -pass-pipeline="builtin.module(func.func(sccp))" %s | FileCheck %s

// Note how sscp creates new constants for the computed values,
// though it does not remove the dead code.

// CHECK-LABEL: @test_arith_sccp
// CHECK-NEXT: %[[v0:.*]] = arith.constant 63 : i32
// CHECK-NEXT: %[[v1:.*]] = arith.constant 49 : i32
// CHECK-NEXT: %[[v2:.*]] = arith.constant 14 : i32
// CHECK-NEXT: %[[v3:.*]] = arith.constant 8 : i32
// CHECK-NEXT: %[[v4:.*]] = arith.constant 7 : i32
// CHECK-NEXT: return %[[v2]] : i32
func.func @test_arith_sccp() -> i32 {
%0 = arith.constant 7 : i32
%1 = arith.constant 8 : i32
%2 = arith.addi %0, %0 : i32
%3 = arith.muli %0, %0 : i32
%4 = arith.addi %2, %3 : i32
return %2 : i32
}

// CHECK-LABEL: @test_poly_sccp
func.func @test_poly_sccp() -> !poly.poly<10> {
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
// CHECK: poly.constant dense<[2, 8, 20, 24, 18]>
// CHECK: poly.constant dense<[1, 4, 10, 12, 9]>
// CHECK: poly.constant dense<[1, 2, 3]>
// CHECK-NOT: poly.mul
// CHECK-NOT: poly.add
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
return %2 : !poly.poly<10>
}

0 comments on commit 7ea3b9c

Please sign in to comment.