diff --git a/tests/filecheck/dialects/fsm/fsm_invalid.mlir b/tests/filecheck/dialects/fsm/fsm_invalid.mlir index 01f8bceef3..a32518ab1a 100644 --- a/tests/filecheck/dialects/fsm/fsm_invalid.mlir +++ b/tests/filecheck/dialects/fsm/fsm_invalid.mlir @@ -141,9 +141,9 @@ "fsm.transition"() ({ "fsm.return"(%arg0) : (i1) -> () }, { - + }) {nextState = @A} : () -> () - + }) {sym_name = "A"} : () -> () }) {function_type = (i1) -> (i1), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> () @@ -160,7 +160,7 @@ "fsm.output"() : () -> () }, { "fsm.transition"() ({ - ^bb1(%arg3: i1): + ^bb1(%arg3: i1): "fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> () "fsm.output"() : () -> () }, { @@ -179,16 +179,16 @@ "fsm.state"() ({ "fsm.output"() : () -> () - + }, { "fsm.transition"() ({ - + }, { - ^bb1(%arg3: i1): + ^bb1(%arg3: i1): "fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> () "fsm.update"(%arg1, %arg2) {variable = "v1" , value = "v2"}: (i16,i16) -> () }) {nextState = @A} : () -> () - + }) {sym_name = "A"} : () -> () }) {function_type = () -> (), initialState = "A", sym_name = "foo", res_names = ["names"],res_attrs = [{"name"="1","type"="2"}] } : () -> () @@ -203,10 +203,10 @@ }, { "fsm.transition"() ({ - + }, { }) {nextState = @A} : () -> () - + }) {sym_name = "A"} : () -> () }) {function_type = (i16) -> (i16) , initialState = "A", sym_name = "foo"} : () -> () @@ -257,7 +257,7 @@ }, { }) {nextState = @A} : () -> () }) {sym_name = "A"} : () -> () - + }) {function_type = (i16) -> (i1), initialState = "A", sym_name = "foo"} : () -> () %arg1 = "arith.constant"() {value = 0 : i16} : () -> i16 %arg2 = "arith.constant"() {value = 0 : i16} : () -> i16 @@ -328,12 +328,12 @@ }) {nextState = @C} : () -> () }) {sym_name = "C"} : () -> () }) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> () - + "func.func"() ({ %3 = "arith.constant"() {value = 16: i16} : () -> i16 - + %4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype - %1 = "arith.constant"() {value = true} : () -> i16 + %1 = "arith.constant"() {value = 0 : i16} : () -> i16 %2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1 "func.return"() : () -> () }) {function_type = () -> (), sym_name = "qux"} : () -> () @@ -371,10 +371,10 @@ }) {nextState = @C} : () -> () }) {sym_name = "C"} : () -> () }) {function_type = (i16) -> (i16), initialState = "A", sym_name = "foo"} : () -> () - + "func.func"() ({ %3 = "arith.constant"() {value = 16: i16} : () -> i16 - + %4 = "fsm.instance"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype %1 = "arith.constant"() {value = true} : () -> i1 %2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i16 @@ -391,8 +391,8 @@ %0 = "fsm.variable"() {initValue = 0 : i16, name = "cnt"} : () -> i16 "fsm.machine"() ({ %4 = "test.op"() {machine = @foo, sym_name = "foo_inst"} : () -> !fsm.instancetype - %1 = "arith.constant"() {value = true} : () -> i16 - %2 = "fsm.trigger"(%1, %4) : (i16, !fsm.instancetype) -> i1 + %1 = "arith.constant"() {value = true} : () -> i1 + %2 = "fsm.trigger"(%1, %4) : (i1, !fsm.instancetype) -> i1 "func.return"() : () -> () }) {function_type = () -> (), sym_name = "qux"} : () -> () diff --git a/tests/filecheck/runner/factorial.mlir b/tests/filecheck/runner/factorial.mlir index 4b61111436..1cabd4474e 100644 --- a/tests/filecheck/runner/factorial.mlir +++ b/tests/filecheck/runner/factorial.mlir @@ -17,7 +17,7 @@ builtin.module { "func.return"(%ret) : (i64) -> () } func.func @main() -> index { - %zero = "arith.constant"() {"value" = 0} : () -> index + %zero = "arith.constant"() {"value" = 0 : index} : () -> index %i = "arith.constant"() {"value" = 12} : () -> i64 %fac = "func.call"(%i) {"callee" = @factorial} : (i64) -> i64 printf.print_format "factorial({})={}", %i : i64, %fac : i64 diff --git a/tests/filecheck/runner/with-wgpu/global_id_inc.mlir b/tests/filecheck/runner/with-wgpu/global_id_inc.mlir index 58d64af288..34fd9a7a82 100644 --- a/tests/filecheck/runner/with-wgpu/global_id_inc.mlir +++ b/tests/filecheck/runner/with-wgpu/global_id_inc.mlir @@ -40,7 +40,7 @@ builtin.module attributes {gpu.container_module} { %hmemref = "memref.alloc"() {"alignment" = 0 : i64, "operandSegmentSizes" = array} : () -> memref<4x4xindex> "gpu.memcpy"(%hmemref, %memref) {"operandSegmentSizes" = array} : (memref<4x4xindex>, memref<4x4xindex>) -> () printf.print_format "Result : {}", %hmemref : memref<4x4xindex> - %zero = "arith.constant"() {"value" = 0} : () -> (index) + %zero = "arith.constant"() {"value" = 0 : index} : () -> (index) "func.return"(%zero) : (index) -> () } } diff --git a/tests/filecheck/transforms/function-constant-pinning.mlir b/tests/filecheck/transforms/function-constant-pinning.mlir index 94b4910d47..3c34141cce 100644 --- a/tests/filecheck/transforms/function-constant-pinning.mlir +++ b/tests/filecheck/transforms/function-constant-pinning.mlir @@ -2,7 +2,7 @@ func.func @basic() -> i32 { - %v = "test.op"() {pin_to_constants = [0]} : () -> i32 + %v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32 func.return %v : i32 } @@ -11,7 +11,7 @@ func.func @basic() -> i32 { // CHECK-NEXT: func.func @basic() -> i32 { // CHECK-NEXT: %v = "test.op"() : () -> i32 // compare the value to the constant we want to specialize for -// CHECK-NEXT: %0 = arith.constant 0 : i64 +// CHECK-NEXT: %0 = arith.constant 0 : i32 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 // CHECK-NEXT: %2 = scf.if %1 -> (i32) { // if they are equal, branch to specialized function @@ -25,7 +25,7 @@ func.func @basic() -> i32 { // specialized function here // CHECK-NEXT: func.func @basic_pinned() -> i32 { // original op is replaced by constant instantiation -// CHECK-NEXT: %v = arith.constant 0 : i64 +// CHECK-NEXT: %v = arith.constant 0 : i32 // CHECK-NEXT: func.return %v : i32 // CHECK-NEXT: } @@ -79,7 +79,7 @@ func.func @control_flow() { func.func @function_args(%arg0: memref<100xf32>) -> i32 { - %v = "test.op"() {pin_to_constants = [0]} : () -> i32 + %v = "test.op"() {pin_to_constants = [0 : i32]} : () -> i32 "test.op"(%v, %arg0) : (i32, memref<100xf32>) -> () @@ -89,7 +89,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 { // CHECK-NEXT: func.func @function_args(%arg0 : memref<100xf32>) -> i32 { // CHECK-NEXT: %v = "test.op"() : () -> i32 -// CHECK-NEXT: %0 = arith.constant 0 : i64 +// CHECK-NEXT: %0 = arith.constant 0 : i32 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 // CHECK-NEXT: %2 = scf.if %1 -> (i32) { // make sure that we forward function args to the specialized function @@ -103,7 +103,7 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 { // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @function_args_pinned(%arg0 : memref<100xf32>) -> i32 { -// CHECK-NEXT: %v = arith.constant 0 : i64 +// CHECK-NEXT: %v = arith.constant 0 : i32 // here the function arg is used // CHECK-NEXT: "test.op"(%v, %arg0) : (i32, memref<100xf32>) -> () // CHECK-NEXT: func.return %v : i32 @@ -155,7 +155,7 @@ func.func @control_flow_and_function_args(%arg: i32) -> i32 { func.func @specialize_multi_case() -> i32 { - %v = "test.op"() {pin_to_constants = [0, 1]} : () -> i32 + %v = "test.op"() {pin_to_constants = [0 : i32, 1 : i32]} : () -> i32 func.return %v : i32 } @@ -164,13 +164,13 @@ func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: %v = "test.op"() : () -> i32 -// CHECK-NEXT: %0 = arith.constant 0 : i64 +// CHECK-NEXT: %0 = arith.constant 0 : i32 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 // CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned_1() : () -> i32 // CHECK-NEXT: scf.yield %3 : i32 // CHECK-NEXT: } else { -// CHECK-NEXT: %4 = arith.constant 1 : i64 +// CHECK-NEXT: %4 = arith.constant 1 : i32 // CHECK-NEXT: %5 = arith.cmpi eq, %v, %4 : i32 // CHECK-NEXT: %6 = scf.if %5 -> (i32) { // CHECK-NEXT: %7 = func.call @specialize_multi_case_pinned() : () -> i32 @@ -185,8 +185,8 @@ func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: func.func @specialize_multi_case_pinned_1() -> i32 { // this function still carries the old specialization check within it, but MLIR can see that // the branch is never taken, so it's completely removed. -// CHECK-NEXT: %v = arith.constant 0 : i64 -// CHECK-NEXT: %0 = arith.constant 1 : i64 +// CHECK-NEXT: %v = arith.constant 0 : i32 +// CHECK-NEXT: %0 = arith.constant 1 : i32 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 // CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned() : () -> i32 @@ -197,7 +197,7 @@ func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @specialize_multi_case_pinned() -> i32 { -// CHECK-NEXT: %v = arith.constant 1 : i64 +// CHECK-NEXT: %v = arith.constant 1 : i32 // CHECK-NEXT: func.return %v : i32 // CHECK-NEXT: } diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 34e09a4241..c16d128831 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -8,6 +8,7 @@ AnyFloat, AnyFloatConstr, AnyIntegerAttr, + AnyIntegerAttrConstr, ContainerOf, DenseIntOrFPElementsAttr, Float16Type, @@ -25,8 +26,11 @@ from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag from xdsl.ir import Attribute, BitEnumAttribute, Dialect, Operation, SSAValue from xdsl.irdl import ( + AnyAttr, AnyOf, + BaseAttr, IRDLOperation, + TypedAttributeConstraint, VarConstraint, base, irdl_attr_definition, @@ -48,7 +52,6 @@ Pure, ) from xdsl.utils.exceptions import VerifyException -from xdsl.utils.isattr import isattr from xdsl.utils.str_enum import StrEnum boolLike = ContainerOf(IntegerType(1)) @@ -124,11 +127,21 @@ def __init__(self, flags: None | Sequence[IntegerOverflowFlag] | Literal["none"] @irdl_op_definition class Constant(IRDLOperation): name = "arith.constant" - result = result_def(Attribute) - value = prop_def(Attribute) + _T: ClassVar = VarConstraint("T", AnyAttr()) + result = result_def(_T) + value = prop_def( + TypedAttributeConstraint( + AnyIntegerAttrConstr + | BaseAttr[FloatAttr[AnyFloat]](FloatAttr) + | BaseAttr(DenseIntOrFPElementsAttr), + _T, + ) + ) traits = traits_def(ConstantLike(), Pure()) + assembly_format = "attr-dict $value" + @overload def __init__( self, @@ -162,31 +175,6 @@ def from_int_and_width( properties={"value": IntegerAttr(value, value_type)}, ) - def print(self, printer: Printer): - printer.print_op_attributes(self.attributes) - - printer.print(" ") - printer.print_attribute(self.value) - - @classmethod - def parse(cls: type[Constant], parser: Parser) -> Constant: - attrs = parser.parse_optional_attr_dict() - - p0 = parser.pos - value = parser.parse_attribute() - - if not isattr( - value, - base(AnyIntegerAttr) - | base(FloatAttr[AnyFloat]) - | base(DenseIntOrFPElementsAttr), - ): - parser.raise_error("Invalid constant value", p0, parser.pos) - - c = Constant(value) - c.attributes.update(attrs) - return c - _T = TypeVar("_T", bound=Attribute) diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 88b0a95617..d87152ddab 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -513,6 +513,9 @@ def parse_with_type( def print_without_type(self, printer: Printer): return printer.print(self.value.data) + def get_type(self) -> Attribute: + return self.type + @staticmethod def constr( *, diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 6af84b9baa..7fcc977b67 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -9,7 +9,13 @@ from typing_extensions import assert_never -from xdsl.ir import Attribute, AttributeCovT, AttributeInvT, ParametrizedAttribute +from xdsl.ir import ( + Attribute, + AttributeCovT, + AttributeInvT, + ParametrizedAttribute, + TypedAttribute, +) from xdsl.utils.exceptions import VerifyException from xdsl.utils.runtime_final import is_runtime_final @@ -206,6 +212,49 @@ def extract_var(self, a: ConstraintVariableTypeT) -> ConstraintVariableType: return a +TypedAttributeCovT = TypeVar("TypedAttributeCovT", bound=TypedAttribute, covariant=True) +TypedAttributeT = TypeVar("TypedAttributeT", bound=TypedAttribute) + + +@dataclass(frozen=True) +class TypedAttributeConstraint(GenericAttrConstraint[TypedAttributeCovT]): + """ + Constrains the type of a typed attribute. + """ + + attr_constraint: GenericAttrConstraint[TypedAttributeCovT] + + type_constraint: GenericAttrConstraint[Attribute] + + def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None: + self.attr_constraint.verify(attr, constraint_context) + if not isinstance(attr, TypedAttribute): + raise VerifyException(f"attribute {attr} expected to be a TypedAttribute") + self.type_constraint.verify(attr.get_type(), constraint_context) + + @dataclass(frozen=True) + class _Extractor(VarExtractor[TypedAttributeT]): + inner: VarExtractor[Attribute] + + def extract_var(self, a: TypedAttributeT) -> ConstraintVariableType: + return self.inner.extract_var(a.get_type()) + + def get_variable_extractors(self) -> dict[str, VarExtractor[TypedAttributeCovT]]: + return merge_extractor_dicts( + self.attr_constraint.get_variable_extractors(), + { + v: self._Extractor(r) + for v, r in self.type_constraint.get_variable_extractors().items() + }, + ) + + def can_infer(self, var_constraint_names: Set[str]) -> bool: + return self.attr_constraint.can_infer(var_constraint_names) + + def infer(self, context: InferenceContext) -> TypedAttributeCovT: + return self.attr_constraint.infer(context) + + @dataclass(frozen=True) class VarConstraint(GenericAttrConstraint[AttributeCovT]): """ diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 7e62f3f0b6..b798bc0812 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -254,6 +254,19 @@ def extract_var(self, a: ParsingState) -> ConstraintVariableType: types = (types,) return self.inner.extract_var(types) + @dataclass(frozen=True) + class _AttrExtractor(VarExtractor[ParsingState]): + name: str + is_prop: bool + inner: VarExtractor[Attribute] + + def extract_var(self, a: ParsingState) -> ConstraintVariableType: + if self.is_prop: + attr = a.properties[self.name] + else: + attr = a.attributes[self.name] + return self.inner.extract_var(attr) + def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]: """ Find out which constraint variables can be inferred from the parsed attributes. @@ -275,6 +288,20 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]: for v, r in result_def.constr.get_variable_extractors().items() } ) + for prop_name, prop_def in self.op_def.properties.items(): + extractor_dicts.append( + { + v: self._AttrExtractor(prop_name, True, r) + for v, r in prop_def.constr.get_variable_extractors().items() + } + ) + for attr_name, attr_def in self.op_def.attributes.items(): + extractor_dicts.append( + { + v: self._AttrExtractor(attr_name, False, r) + for v, r in attr_def.constr.get_variable_extractors().items() + } + ) return merge_extractor_dicts(*extractor_dicts) def verify_operands(self, var_constraint_names: Set[str]):