Skip to content

Commit

Permalink
core: introduce TypedAttributeConstraint (#3318)
Browse files Browse the repository at this point in the history
Introduces a constraint on `TypedAttribute`s which allows the user to put a constraint on the type of the TypedAttribute. Also updates `arith.constant` to use this new constraint and fixes some cases where it was used incorrectly before.
  • Loading branch information
alexarice authored Nov 22, 2024
1 parent 018d17f commit d9e2fa1
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 60 deletions.
34 changes: 17 additions & 17 deletions tests/filecheck/dialects/fsm/fsm_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@
"fsm.transition"() ({
"fsm.return"(%arg0) : (i1) -> ()
}, {

}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = (i1) -> (i1), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -160,7 +160,7 @@
"fsm.output"() : () -> ()
}, {
"fsm.transition"() ({
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.output"() : () -> ()
}, {
Expand All @@ -179,16 +179,16 @@

"fsm.state"() ({
"fsm.output"() : () -> ()

}, {
"fsm.transition"() ({

}, {
^bb1(%arg3: i1):
^bb1(%arg3: i1):
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
"fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> ()
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()
}) {function_type = () -> (), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> ()

Expand All @@ -203,10 +203,10 @@

}, {
"fsm.transition"() ({

}, {
}) {nextState = @A} : () -> ()

}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i16) , initialState = "A", sym_name = "foo"} : () -> ()
Expand Down Expand Up @@ -257,7 +257,7 @@
}, {
}) {nextState = @A} : () -> ()
}) {sym_name = "A"} : () -> ()

}) {function_type = (i16) -> (i1), initialState = "A", sym_name = "foo"} : () -> ()
%arg1 = "arith.constant"() {value = 0 : i16} : () -> i16
%arg2 = "arith.constant"() {value = 0 : i16} : () -> i16
Expand Down Expand Up @@ -328,12 +328,12 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%1 = "arith.constant"() {value = 0 : i16} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()
Expand Down Expand Up @@ -371,10 +371,10 @@
}) {nextState = @C} : () -> ()
}) {sym_name = "C"} : () -> ()
}) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> ()

"func.func"() ({
%3 = "arith.constant"() {value = 16: i16} : () -> i16

%4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i16
Expand All @@ -391,8 +391,8 @@
%0 = "fsm.variable"() {initValue = 0 : i16, name = "cnt"} : () -> i16
"fsm.machine"() ({
%4 = "test.op"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype
%1 = "arith.constant"() {value = true} : () -> i16
%2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1
%1 = "arith.constant"() {value = true} : () -> i1
%2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i1
"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "qux"} : () -> ()

Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/runner/factorial.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ builtin.module {
"func.return"(%ret) : (i64) -> ()
}
func.func @main() -> index {
%zero = "arith.constant"() {"value" = 0} : () -> index
%zero = "arith.constant"() {"value" = 0 : index} : () -> index
%i = "arith.constant"() {"value" = 12} : () -> i64
%fac = "func.call"(%i) {"callee" = @factorial} : (i64) -> i64
printf.print_format "factorial({})={}", %i : i64, %fac : i64
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/runner/with-wgpu/global_id_inc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ builtin.module attributes {gpu.container_module} {
%hmemref = "memref.alloc"() {"alignment" = 0 : i64, "operandSegmentSizes" = array<i32: 0, 0>} : () -> memref<4x4xindex>
"gpu.memcpy"(%hmemref, %memref) {"operandSegmentSizes" = array<i32: 0, 1, 1>} : (memref<4x4xindex>, memref<4x4xindex>) -> ()
printf.print_format "Result : {}", %hmemref : memref<4x4xindex>
%zero = "arith.constant"() {"value" = 0} : () -> (index)
%zero = "arith.constant"() {"value" = 0 : index} : () -> (index)
"func.return"(%zero) : (index) -> ()
}
}
Expand Down
24 changes: 12 additions & 12 deletions tests/filecheck/transforms/function-constant-pinning.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


func.func @basic() -> i32 {
%v = "test.op"() {pin_to_constants = [0]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32
func.return %v : i32
}

Expand All @@ -11,7 +11,7 @@ func.func @basic() -> i32 {
// CHECK-NEXT: func.func @basic() -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// compare the value to the constant we want to specialize for
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// if they are equal, branch to specialized function
Expand All @@ -25,7 +25,7 @@ func.func @basic() -> i32 {
// specialized function here
// CHECK-NEXT: func.func @basic_pinned() -> i32 {
// original op is replaced by constant instantiation
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// CHECK-NEXT: func.return %v : i32
// CHECK-NEXT: }

Expand Down Expand Up @@ -79,7 +79,7 @@ func.func @control_flow() {


func.func @function_args(%arg0: memref<100xf32>) -> i32 {
%v = "test.op"() {pin_to_constants = [0]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32

"test.op"(%v, %arg0) : (i32, memref<100xf32>) -> ()

Expand All @@ -89,7 +89,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 {

// CHECK-NEXT: func.func @function_args(%arg0 : memref<100xf32>) -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// make sure that we forward function args to the specialized function
Expand All @@ -103,7 +103,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 {
// CHECK-NEXT: func.return %2 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @function_args_pinned(%arg0 : memref<100xf32>) -> i32 {
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// here the function arg is used
// CHECK-NEXT: "test.op"(%v, %arg0) : (i32, memref<100xf32>) -> ()
// CHECK-NEXT: func.return %v : i32
Expand Down Expand Up @@ -155,7 +155,7 @@ func.func @control_flow_and_function_args(%arg: i32) -> i32 {


func.func @specialize_multi_case() -> i32 {
%v = "test.op"() {pin_to_constants = [0, 1]} : () -> i32
%v = "test.op"() {pin_to_constants = [0 : i32, 1 : i32]} : () -> i32
func.return %v : i32
}

Expand All @@ -164,13 +164,13 @@ func.func @specialize_multi_case() -> i32 {

// CHECK-NEXT: func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: %v = "test.op"() : () -> i32
// CHECK-NEXT: %0 = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 0 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned_1() : () -> i32
// CHECK-NEXT: scf.yield %3 : i32
// CHECK-NEXT: } else {
// CHECK-NEXT: %4 = arith.constant 1 : i64
// CHECK-NEXT: %4 = arith.constant 1 : i32
// CHECK-NEXT: %5 = arith.cmpi eq, %v, %4 : i32
// CHECK-NEXT: %6 = scf.if %5 -> (i32) {
// CHECK-NEXT: %7 = func.call @specialize_multi_case_pinned() : () -> i32
Expand All @@ -185,8 +185,8 @@ func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: func.func @specialize_multi_case_pinned_1() -> i32 {
// this function still carries the old specialization check within it, but MLIR can see that
// the branch is never taken, so it's completely removed.
// CHECK-NEXT: %v = arith.constant 0 : i64
// CHECK-NEXT: %0 = arith.constant 1 : i64
// CHECK-NEXT: %v = arith.constant 0 : i32
// CHECK-NEXT: %0 = arith.constant 1 : i32
// CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32
// CHECK-NEXT: %2 = scf.if %1 -> (i32) {
// CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned() : () -> i32
Expand All @@ -197,7 +197,7 @@ func.func @specialize_multi_case() -> i32 {
// CHECK-NEXT: func.return %2 : i32
// CHECK-NEXT: }
// CHECK-NEXT: func.func @specialize_multi_case_pinned() -> i32 {
// CHECK-NEXT: %v = arith.constant 1 : i64
// CHECK-NEXT: %v = arith.constant 1 : i32
// CHECK-NEXT: func.return %v : i32
// CHECK-NEXT: }

Expand Down
44 changes: 16 additions & 28 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
AnyIntegerAttrConstr,
ContainerOf,
DenseIntOrFPElementsAttr,
Float16Type,
Expand All @@ -25,8 +26,11 @@
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyAttr,
AnyOf,
BaseAttr,
IRDLOperation,
TypedAttributeConstraint,
VarConstraint,
base,
irdl_attr_definition,
Expand All @@ -48,7 +52,6 @@
Pure,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.isattr import isattr
from xdsl.utils.str_enum import StrEnum

boolLike = ContainerOf(IntegerType(1))
Expand Down Expand Up @@ -124,11 +127,21 @@ def __init__(self, flags: None | Sequence[IntegerOverflowFlag] | Literal["none"]
@irdl_op_definition
class Constant(IRDLOperation):
name = "arith.constant"
result = result_def(Attribute)
value = prop_def(Attribute)
_T: ClassVar = VarConstraint("T", AnyAttr())
result = result_def(_T)
value = prop_def(
TypedAttributeConstraint(
AnyIntegerAttrConstr
| BaseAttr[FloatAttr[AnyFloat]](FloatAttr)
| BaseAttr(DenseIntOrFPElementsAttr),
_T,
)
)

traits = traits_def(ConstantLike(), Pure())

assembly_format = "attr-dict $value"

@overload
def __init__(
self,
Expand Down Expand Up @@ -162,31 +175,6 @@ def from_int_and_width(
properties={"value": IntegerAttr(value, value_type)},
)

def print(self, printer: Printer):
printer.print_op_attributes(self.attributes)

printer.print(" ")
printer.print_attribute(self.value)

@classmethod
def parse(cls: type[Constant], parser: Parser) -> Constant:
attrs = parser.parse_optional_attr_dict()

p0 = parser.pos
value = parser.parse_attribute()

if not isattr(
value,
base(AnyIntegerAttr)
| base(FloatAttr[AnyFloat])
| base(DenseIntOrFPElementsAttr),
):
parser.raise_error("Invalid constant value", p0, parser.pos)

c = Constant(value)
c.attributes.update(attrs)
return c


_T = TypeVar("_T", bound=Attribute)

Expand Down
3 changes: 3 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ def parse_with_type(
def print_without_type(self, printer: Printer):
return printer.print(self.value.data)

def get_type(self) -> Attribute:
return self.type

@staticmethod
def constr(
*,
Expand Down
51 changes: 50 additions & 1 deletion xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@

from typing_extensions import assert_never

from xdsl.ir import Attribute, AttributeCovT, AttributeInvT, ParametrizedAttribute
from xdsl.ir import (
Attribute,
AttributeCovT,
AttributeInvT,
ParametrizedAttribute,
TypedAttribute,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.runtime_final import is_runtime_final

Expand Down Expand Up @@ -206,6 +212,49 @@ def extract_var(self, a: ConstraintVariableTypeT) -> ConstraintVariableType:
return a


TypedAttributeCovT = TypeVar("TypedAttributeCovT", bound=TypedAttribute, covariant=True)
TypedAttributeT = TypeVar("TypedAttributeT", bound=TypedAttribute)


@dataclass(frozen=True)
class TypedAttributeConstraint(GenericAttrConstraint[TypedAttributeCovT]):
"""
Constrains the type of a typed attribute.
"""

attr_constraint: GenericAttrConstraint[TypedAttributeCovT]

type_constraint: GenericAttrConstraint[Attribute]

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
self.attr_constraint.verify(attr, constraint_context)
if not isinstance(attr, TypedAttribute):
raise VerifyException(f"attribute {attr} expected to be a TypedAttribute")
self.type_constraint.verify(attr.get_type(), constraint_context)

@dataclass(frozen=True)
class _Extractor(VarExtractor[TypedAttributeT]):
inner: VarExtractor[Attribute]

def extract_var(self, a: TypedAttributeT) -> ConstraintVariableType:
return self.inner.extract_var(a.get_type())

def get_variable_extractors(self) -> dict[str, VarExtractor[TypedAttributeCovT]]:
return merge_extractor_dicts(
self.attr_constraint.get_variable_extractors(),
{
v: self._Extractor(r)
for v, r in self.type_constraint.get_variable_extractors().items()
},
)

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.attr_constraint.can_infer(var_constraint_names)

def infer(self, context: InferenceContext) -> TypedAttributeCovT:
return self.attr_constraint.infer(context)


@dataclass(frozen=True)
class VarConstraint(GenericAttrConstraint[AttributeCovT]):
"""
Expand Down
Loading

0 comments on commit d9e2fa1

Please sign in to comment.