diff --git a/lib/Dialect/Poly/BUILD b/lib/Dialect/Poly/BUILD index c1c36f7..9dfb6d4 100644 --- a/lib/Dialect/Poly/BUILD +++ b/lib/Dialect/Poly/BUILD @@ -13,9 +13,9 @@ td_library( ], includes = ["@heir//include"], deps = [ - # the base mlir target for defining operations and dialects in tablegen - "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", ], ) @@ -89,6 +89,7 @@ cc_library( hdrs = [ "PolyDialect.h", "PolyOps.h", + "PolyTraits.h", "PolyTypes.h", ], deps = [ @@ -97,6 +98,7 @@ cc_library( ":types_inc_gen", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Support", ], ) diff --git a/lib/Dialect/Poly/PolyOps.cpp b/lib/Dialect/Poly/PolyOps.cpp index b54b117..b92c58a 100644 --- a/lib/Dialect/Poly/PolyOps.cpp +++ b/lib/Dialect/Poly/PolyOps.cpp @@ -61,6 +61,12 @@ OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) { return dyn_cast(adaptor.getInput()); } +LogicalResult EvalOp::verify() { + return getPoint().getType().isSignlessInteger(32) + ? success() + : emitOpError("argument point must be a 32-bit integer"); +} + } // namespace poly } // namespace tutorial } // namespace mlir diff --git a/lib/Dialect/Poly/PolyOps.h b/lib/Dialect/Poly/PolyOps.h index 5ebeebd..c562244 100644 --- a/lib/Dialect/Poly/PolyOps.h +++ b/lib/Dialect/Poly/PolyOps.h @@ -2,12 +2,14 @@ #define LIB_DIALECT_POLY_POLYOPS_H_ #include "lib/Dialect/Poly/PolyDialect.h" +#include "lib/Dialect/Poly/PolyTraits.h" #include "lib/Dialect/Poly/PolyTypes.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project #define GET_OP_CLASSES #include "lib/Dialect/Poly/PolyOps.h.inc" -#endif // LIB_DIALECT_POLY_POLYOPS_H_ +#endif // LIB_DIALECT_POLY_POLYOPS_H_ diff --git a/lib/Dialect/Poly/PolyOps.td b/lib/Dialect/Poly/PolyOps.td index f5e1dcd..520ab09 100644 --- a/lib/Dialect/Poly/PolyOps.td +++ b/lib/Dialect/Poly/PolyOps.td @@ -5,16 +5,22 @@ include "PolyDialect.td" include "PolyTypes.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Type constraint for poly binop arguments: polys, vectors of polys, or // tensors of polys. def PolyOrContainer : TypeOrContainer; -class Poly_BinOp : Op { +// Inject verification that all integer-like arguments are 32-bits +def Has32BitArguments : NativeOpTrait<"Has32BitArguments"> { + let cppNamespace = "::mlir::tutorial::poly"; +} + +class Poly_BinOp : Op { let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs); let results = (outs PolyOrContainer:$output); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)"; + let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))"; let hasFolder = 1; } @@ -34,22 +40,23 @@ def Poly_FromTensorOp : Op { let summary = "Creates a Polynomial from integer coefficients stored in a tensor."; let arguments = (ins TensorOf<[AnyInteger]>:$input); let results = (outs Polynomial:$output); - let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; + let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))"; let hasFolder = 1; } -def Poly_EvalOp : Op { +def Poly_EvalOp : Op, Has32BitArguments]> { let summary = "Evaluates a Polynomial at a given input value."; let arguments = (ins Polynomial:$input, AnyInteger:$point); let results = (outs AnyInteger:$output); - let assemblyFormat = "$input `,` $point attr-dict `:` `(` type($input) `,` type($point) `)` `->` type($output)"; + let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)"; + let hasVerifier = 1; } def Poly_ConstantOp : Op { let summary = "Define a constant polynomial via an attribute."; let arguments = (ins AnyIntElementsAttr:$coefficients); let results = (outs Polynomial:$output); - let assemblyFormat = "$coefficients attr-dict `:` type($output)"; + let assemblyFormat = "$coefficients attr-dict `:` qualified(type($output))"; let hasFolder = 1; } diff --git a/lib/Dialect/Poly/PolyTraits.h b/lib/Dialect/Poly/PolyTraits.h new file mode 100644 index 0000000..d220f1c --- /dev/null +++ b/lib/Dialect/Poly/PolyTraits.h @@ -0,0 +1,28 @@ +#ifndef LIB_DIALECT_POLY_POLYTRAITS_H_ +#define LIB_DIALECT_POLY_POLYTRAITS_H_ + +#include "mlir/include/mlir/IR/OpDefinition.h" + +namespace mlir::tutorial::poly { + +template +class Has32BitArguments : public OpTrait::TraitBase { + public: + static LogicalResult verifyTrait(Operation *op) { + for (auto type : op->getOperandTypes()) { + // OK to skip non-integer operand types + if (!type.isIntOrIndex()) continue; + + if (!type.isInteger(32)) { + return op->emitOpError() + << "requires each numeric operand to be a 32-bit integer"; + } + } + + return success(); + } +}; + +} + +#endif // LIB_DIALECT_POLY_POLYTRAITS_H_ diff --git a/tests/code_motion.mlir b/tests/code_motion.mlir index 3fc8b1a..51004a3 100644 --- a/tests/code_motion.mlir +++ b/tests/code_motion.mlir @@ -15,8 +15,8 @@ module { %ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> { // The poly.mul should be hoisted out of the loop. // CHECK-NOT: poly.mul - %2 = poly.mul %p0, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %sum_next = poly.add %sum_iter, %2 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.mul %p0, %p1 : !poly.poly<10> + %sum_next = poly.add %sum_iter, %2 : !poly.poly<10> affine.yield %sum_next : !poly.poly<10> } diff --git a/tests/control_flow_sink.mlir b/tests/control_flow_sink.mlir index 4354b48..d90fb68 100644 --- a/tests/control_flow_sink.mlir +++ b/tests/control_flow_sink.mlir @@ -12,12 +12,12 @@ func.func @test_simple_sink(%arg0: i1) -> !poly.poly<10> { // CHECK: scf.if %4 = scf.if %arg0 -> (!poly.poly<10>) { // CHECK: poly.from_tensor - %2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.mul %p0, %p0 : !poly.poly<10> scf.yield %2 : !poly.poly<10> // CHECK: else } else { // CHECK: poly.from_tensor - %3 = poly.mul %p1, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %3 = poly.mul %p1, %p1 : !poly.poly<10> scf.yield %3 : !poly.poly<10> } return %4 : !poly.poly<10> diff --git a/tests/cse.mlir b/tests/cse.mlir index 80d23f2..899e402 100644 --- a/tests/cse.mlir +++ b/tests/cse.mlir @@ -8,8 +8,8 @@ func.func @test_simple_cse() -> !poly.poly<10> { // exactly one mul op // CHECK-NEXT: poly.mul // CHECK-NEXT: poly.add - %2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.mul %p0, %p0 : !poly.poly<10> + %3 = poly.mul %p0, %p0 : !poly.poly<10> + %4 = poly.add %2, %3 : !poly.poly<10> return %4 : !poly.poly<10> } diff --git a/tests/poly_canonicalize.mlir b/tests/poly_canonicalize.mlir index 8d8636a..0b219ab 100644 --- a/tests/poly_canonicalize.mlir +++ b/tests/poly_canonicalize.mlir @@ -6,8 +6,8 @@ func.func @test_simple() -> !poly.poly<10> { // CHECK-NEXT: return %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> %p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10> - %2 = poly.add %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.add %p0, %p0 : !poly.poly<10> + %3 = poly.mul %p0, %p0 : !poly.poly<10> + %4 = poly.add %2, %3 : !poly.poly<10> return %2 : !poly.poly<10> } diff --git a/tests/poly_syntax.mlir b/tests/poly_syntax.mlir index 0eb66b3..8ab6db8 100644 --- a/tests/poly_syntax.mlir +++ b/tests/poly_syntax.mlir @@ -11,11 +11,11 @@ module { // CHECK-LABEL: test_op_syntax func.func @test_op_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> { // CHECK: poly.add - %0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %0 = poly.add %arg0, %arg1 : !poly.poly<10> // CHECK: poly.sub - %1 = poly.sub %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %1 = poly.sub %arg0, %arg1 : !poly.poly<10> // CHECK: poly.mul - %2 = poly.mul %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.mul %arg0, %arg1 : !poly.poly<10> %3 = arith.constant dense<[1, 2, 3]> : tensor<3xi32> // CHECK: poly.from_tensor @@ -27,9 +27,7 @@ module { %7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>> // CHECK: poly.add - %8 = poly.add %7, %7 : (tensor<2x!poly.poly<10>>, tensor<2x!poly.poly<10>>) -> tensor<2x!poly.poly<10>> - // CHECK: poly.add - %9 = poly.add %7, %4 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>> + %8 = poly.add %7, %7 : tensor<2x!poly.poly<10>> // CHECK: poly.constant %10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10> diff --git a/tests/poly_verifier.mlir b/tests/poly_verifier.mlir new file mode 100644 index 0000000..fadca35 --- /dev/null +++ b/tests/poly_verifier.mlir @@ -0,0 +1,10 @@ +// RUN: tutorial-opt %s 2>%t; FileCheck %s < %t + +func.func @test_invalid_evalop(%arg0: !poly.poly<10>, %cst: i64) -> i64 { + // This is a little brittle, since it matches both the error message + // emitted by Has32BitArguments as well as that of EvalOp::verify. + // I manually tested that they both fire when the input is as below. + // CHECK: to be a 32-bit integer + %0 = poly.eval %arg0, %cst : (!poly.poly<10>, i64) -> i64 + return %0 : i64 +} diff --git a/tests/sccp.mlir b/tests/sccp.mlir index 39b3a24..2243119 100644 --- a/tests/sccp.mlir +++ b/tests/sccp.mlir @@ -28,8 +28,8 @@ func.func @test_poly_sccp() -> !poly.poly<10> { // CHECK: poly.constant dense<[1, 2, 3]> // CHECK-NOT: poly.mul // CHECK-NOT: poly.add - %2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> - %4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10> + %2 = poly.mul %p0, %p0 : !poly.poly<10> + %3 = poly.mul %p0, %p0 : !poly.poly<10> + %4 = poly.add %2, %3 : !poly.poly<10> return %2 : !poly.poly<10> }