From d365db8767a00d89e51f79c86f275cd6548cac27 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Thu, 28 Nov 2024 22:51:13 +0000 Subject: [PATCH] core: Fix extractor logic for optional and default-valued properties (#3525) The extractor used in the assembly parser assumed that all properties/attributes were present. We should instead only propagate the extractors from the property constraint if the property is non-optional or there is a default value. The default_value also needs to be given to the extractor for when the property is not present. Adds 2 tests to demonstrate this. --- .../irdl/test_declarative_assembly_format.py | 69 +++++++++++++++++++ .../declarative_assembly_format_parser.py | 21 ++++-- 2 files changed, 86 insertions(+), 4 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index aade6fbfd5..4a22569f14 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -11,6 +11,7 @@ from xdsl.dialects import test from xdsl.dialects.builtin import ( I32, + AnyIntegerAttrConstr, BoolAttr, Float64Type, FloatAttr, @@ -41,6 +42,7 @@ ParsePropInAttrDict, RangeOf, RangeVarConstraint, + TypedAttributeConstraint, VarConstraint, VarOperand, VarOpResult, @@ -2637,6 +2639,73 @@ def test_renamed_optional_prop(program: str, output: str, generic: str): assert generic == stream.getvalue() +@pytest.mark.parametrize( + "program, generic", + [ + ( + "test.opt_constant : ()", + '"test.opt_constant"() : () -> ()', + ), + ( + "%0 = test.opt_constant value 1 : i32 : (i32)", + '%0 = "test.opt_constant"() <{"value" = 1 : i32}> : () -> (i32)', + ), + ], +) +def test_optional_property_with_extractor(program: str, generic: str): + @irdl_op_definition + class OptConstantOp(IRDLOperation): + name = "test.opt_constant" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = opt_prop_def(TypedAttributeConstraint(AnyIntegerAttrConstr, T)) + + res = opt_result_def(T) + + assembly_format = "(`value` $value^)? attr-dict `:` `(` type($res) `)`" + + ctx = MLContext() + ctx.load_op(OptConstantOp) + + check_roundtrip(program, ctx) + check_equivalence(program, generic, ctx) + + +@pytest.mark.parametrize( + "program, generic", + [ + ( + "%0 = test.default_constant", + '%0 = "test.default_constant"() <{"value" = true}> : () -> (i1)', + ), + ( + "%0 = test.default_constant value 2 : i32", + '%0 = "test.default_constant"() <{"value" = 2 : i32}> : () -> (i32)', + ), + ], +) +def test_default_property_with_extractor(program: str, generic: str): + @irdl_op_definition + class DefaultConstantOp(IRDLOperation): + name = "test.default_constant" + T: ClassVar = VarConstraint("T", AnyAttr()) + + value = prop_def( + TypedAttributeConstraint(AnyIntegerAttrConstr, T), + default_value=BoolAttr.from_bool(True), + ) + + res = result_def(T) + + assembly_format = "(`value` $value^)? attr-dict" + + ctx = MLContext() + ctx.load_op(DefaultConstantOp) + + check_roundtrip(program, ctx) + check_equivalence(program, generic, ctx) + + ################################################################################ # Extractors # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 4250ff6b55..e11f085c46 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -247,15 +247,24 @@ def extract_var(self, a: ParsingState) -> ConstraintVariableType: @dataclass(frozen=True) class _AttrExtractor(VarExtractor[ParsingState]): + """ + Extracts constraint variables from the attributes/properties of an operation. + If the default_value field is None, then the attribute/property must always be + present (which is only the case for non-optional attributes/properties with no + default value). + """ + name: str is_prop: bool inner: VarExtractor[Attribute] + default_value: Attribute | None def extract_var(self, a: ParsingState) -> ConstraintVariableType: if self.is_prop: - attr = a.properties[self.name] + attr = a.properties.get(self.name, self.default_value) else: - attr = a.attributes[self.name] + attr = a.attributes.get(self.name, self.default_value) + assert attr is not None return self.inner.extract_var(attr) def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]: @@ -280,16 +289,20 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]: } ) for prop_name, prop_def in self.op_def.properties.items(): + if isinstance(prop_def, OptionalDef) and prop_def.default_value is None: + continue extractor_dicts.append( { - v: self._AttrExtractor(prop_name, True, r) + v: self._AttrExtractor(prop_name, True, r, prop_def.default_value) for v, r in prop_def.constr.get_variable_extractors().items() } ) for attr_name, attr_def in self.op_def.attributes.items(): + if isinstance(attr_def, OptionalDef) and attr_def.default_value is None: + continue extractor_dicts.append( { - v: self._AttrExtractor(attr_name, False, r) + v: self._AttrExtractor(attr_name, False, r, attr_def.default_value) for v, r in attr_def.constr.get_variable_extractors().items() } )