Skip to content

Commit

Permalink
Verifiers (#17)
Browse files Browse the repository at this point in the history
* Add `qualified` to the type assembly format

This is because I noticed the assembly format was printing <10> for
polynomials instead of the fully qualified type name. After this commit
it will print the whole type

See https://mlir.llvm.org/docs/DefiningDialects/Operations/#declarative-assembly-format
for more details.

* Add SameOperandsAndResultType

This removes the flexibility of having mixed poly + tensor ops for the
binary operations, but demonstrates how the type inference engine
enables a more succinct textual IR.

If you were to simplify the assembly format without doing this, you'd
get a compile-time error complaining that it can't infer the type of the
operands or argument.

* add AllTypesMatch to EvalOp

* add a custom verifier for evalop

* add verifier via trait
  • Loading branch information
j2kun authored Sep 12, 2023
1 parent c387ac0 commit 031e5fe
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 31 deletions.
6 changes: 4 additions & 2 deletions lib/Dialect/Poly/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ td_library(
],
includes = ["@heir//include"],
deps = [
# the base mlir target for defining operations and dialects in tablegen
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
],
)
Expand Down Expand Up @@ -89,6 +89,7 @@ cc_library(
hdrs = [
"PolyDialect.h",
"PolyOps.h",
"PolyTraits.h",
"PolyTypes.h",
],
deps = [
Expand All @@ -97,6 +98,7 @@ cc_library(
":types_inc_gen",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
],
)
6 changes: 6 additions & 0 deletions lib/Dialect/Poly/PolyOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
}

LogicalResult EvalOp::verify() {
return getPoint().getType().isSignlessInteger(32)
? success()
: emitOpError("argument point must be a 32-bit integer");
}

} // namespace poly
} // namespace tutorial
} // namespace mlir
10 changes: 6 additions & 4 deletions lib/Dialect/Poly/PolyOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#define LIB_DIALECT_POLY_POLYOPS_H_

#include "lib/Dialect/Poly/PolyDialect.h"
#include "lib/Dialect/Poly/PolyTraits.h"
#include "lib/Dialect/Poly/PolyTypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

#define GET_OP_CLASSES
#include "lib/Dialect/Poly/PolyOps.h.inc"

#endif // LIB_DIALECT_POLY_POLYOPS_H_
#endif // LIB_DIALECT_POLY_POLYOPS_H_
19 changes: 13 additions & 6 deletions lib/Dialect/Poly/PolyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ include "PolyDialect.td"
include "PolyTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

// Type constraint for poly binop arguments: polys, vectors of polys, or
// tensors of polys.
def PolyOrContainer : TypeOrContainer<Polynomial, "poly-or-container">;

class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable, SameOperandsAndResultElementType]> {
// Inject verification that all integer-like arguments are 32-bits
def Has32BitArguments : NativeOpTrait<"Has32BitArguments"> {
let cppNamespace = "::mlir::tutorial::poly";
}

class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable, SameOperandsAndResultType]> {
let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs);
let results = (outs PolyOrContainer:$output);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
let hasFolder = 1;
}

Expand All @@ -34,22 +40,23 @@ def Poly_FromTensorOp : Op<Poly_Dialect, "from_tensor", [Pure]> {
let summary = "Creates a Polynomial from integer coefficients stored in a tensor.";
let arguments = (ins TensorOf<[AnyInteger]>:$input);
let results = (outs Polynomial:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))";
let hasFolder = 1;
}

def Poly_EvalOp : Op<Poly_Dialect, "eval"> {
def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>, Has32BitArguments]> {
let summary = "Evaluates a Polynomial at a given input value.";
let arguments = (ins Polynomial:$input, AnyInteger:$point);
let results = (outs AnyInteger:$output);
let assemblyFormat = "$input `,` $point attr-dict `:` `(` type($input) `,` type($point) `)` `->` type($output)";
let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
let hasVerifier = 1;
}

def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {
let summary = "Define a constant polynomial via an attribute.";
let arguments = (ins AnyIntElementsAttr:$coefficients);
let results = (outs Polynomial:$output);
let assemblyFormat = "$coefficients attr-dict `:` type($output)";
let assemblyFormat = "$coefficients attr-dict `:` qualified(type($output))";
let hasFolder = 1;
}

Expand Down
28 changes: 28 additions & 0 deletions lib/Dialect/Poly/PolyTraits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef LIB_DIALECT_POLY_POLYTRAITS_H_
#define LIB_DIALECT_POLY_POLYTRAITS_H_

#include "mlir/include/mlir/IR/OpDefinition.h"

namespace mlir::tutorial::poly {

template <typename ConcreteType>
class Has32BitArguments : public OpTrait::TraitBase<ConcreteType, Has32BitArguments> {
public:
static LogicalResult verifyTrait(Operation *op) {
for (auto type : op->getOperandTypes()) {
// OK to skip non-integer operand types
if (!type.isIntOrIndex()) continue;

if (!type.isInteger(32)) {
return op->emitOpError()
<< "requires each numeric operand to be a 32-bit integer";
}
}

return success();
}
};

}

#endif // LIB_DIALECT_POLY_POLYTRAITS_H_
4 changes: 2 additions & 2 deletions tests/code_motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ module {
%ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> {
// The poly.mul should be hoisted out of the loop.
// CHECK-NOT: poly.mul
%2 = poly.mul %p0, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%sum_next = poly.add %sum_iter, %2 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.mul %p0, %p1 : !poly.poly<10>
%sum_next = poly.add %sum_iter, %2 : !poly.poly<10>
affine.yield %sum_next : !poly.poly<10>
}

Expand Down
4 changes: 2 additions & 2 deletions tests/control_flow_sink.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ func.func @test_simple_sink(%arg0: i1) -> !poly.poly<10> {
// CHECK: scf.if
%4 = scf.if %arg0 -> (!poly.poly<10>) {
// CHECK: poly.from_tensor
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.mul %p0, %p0 : !poly.poly<10>
scf.yield %2 : !poly.poly<10>
// CHECK: else
} else {
// CHECK: poly.from_tensor
%3 = poly.mul %p1, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p1, %p1 : !poly.poly<10>
scf.yield %3 : !poly.poly<10>
}
return %4 : !poly.poly<10>
Expand Down
6 changes: 3 additions & 3 deletions tests/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ func.func @test_simple_cse() -> !poly.poly<10> {
// exactly one mul op
// CHECK-NEXT: poly.mul
// CHECK-NEXT: poly.add
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.mul %p0, %p0 : !poly.poly<10>
%3 = poly.mul %p0, %p0 : !poly.poly<10>
%4 = poly.add %2, %3 : !poly.poly<10>
return %4 : !poly.poly<10>
}
6 changes: 3 additions & 3 deletions tests/poly_canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ func.func @test_simple() -> !poly.poly<10> {
// CHECK-NEXT: return
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
%2 = poly.add %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.add %p0, %p0 : !poly.poly<10>
%3 = poly.mul %p0, %p0 : !poly.poly<10>
%4 = poly.add %2, %3 : !poly.poly<10>
return %2 : !poly.poly<10>
}
10 changes: 4 additions & 6 deletions tests/poly_syntax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ module {
// CHECK-LABEL: test_op_syntax
func.func @test_op_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
// CHECK: poly.add
%0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%0 = poly.add %arg0, %arg1 : !poly.poly<10>
// CHECK: poly.sub
%1 = poly.sub %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%1 = poly.sub %arg0, %arg1 : !poly.poly<10>
// CHECK: poly.mul
%2 = poly.mul %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.mul %arg0, %arg1 : !poly.poly<10>

%3 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
// CHECK: poly.from_tensor
Expand All @@ -27,9 +27,7 @@ module {

%7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>>
// CHECK: poly.add
%8 = poly.add %7, %7 : (tensor<2x!poly.poly<10>>, tensor<2x!poly.poly<10>>) -> tensor<2x!poly.poly<10>>
// CHECK: poly.add
%9 = poly.add %7, %4 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>>
%8 = poly.add %7, %7 : tensor<2x!poly.poly<10>>

// CHECK: poly.constant
%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>
Expand Down
10 changes: 10 additions & 0 deletions tests/poly_verifier.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: tutorial-opt %s 2>%t; FileCheck %s < %t

func.func @test_invalid_evalop(%arg0: !poly.poly<10>, %cst: i64) -> i64 {
// This is a little brittle, since it matches both the error message
// emitted by Has32BitArguments as well as that of EvalOp::verify.
// I manually tested that they both fire when the input is as below.
// CHECK: to be a 32-bit integer
%0 = poly.eval %arg0, %cst : (!poly.poly<10>, i64) -> i64
return %0 : i64
}
6 changes: 3 additions & 3 deletions tests/sccp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func.func @test_poly_sccp() -> !poly.poly<10> {
// CHECK: poly.constant dense<[1, 2, 3]>
// CHECK-NOT: poly.mul
// CHECK-NOT: poly.add
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
%2 = poly.mul %p0, %p0 : !poly.poly<10>
%3 = poly.mul %p0, %p0 : !poly.poly<10>
%4 = poly.add %2, %3 : !poly.poly<10>
return %2 : !poly.poly<10>
}

0 comments on commit 031e5fe

Please sign in to comment.