Skip to content

Commit

Permalink
core: Fix extractor logic for optional and default-valued properties (#…
Browse files Browse the repository at this point in the history
…3525)

The extractor used in the assembly parser assumed that all
properties/attributes were present. We should instead only propagate the
extractors from the property constraint if the property is non-optional
or there is a default value. The default_value also needs to be given to
the extractor for when the property is not present.

Adds 2 tests to demonstrate this.
  • Loading branch information
alexarice authored Nov 28, 2024
1 parent c624080 commit d365db8
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 4 deletions.
69 changes: 69 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xdsl.dialects import test
from xdsl.dialects.builtin import (
I32,
AnyIntegerAttrConstr,
BoolAttr,
Float64Type,
FloatAttr,
Expand Down Expand Up @@ -41,6 +42,7 @@
ParsePropInAttrDict,
RangeOf,
RangeVarConstraint,
TypedAttributeConstraint,
VarConstraint,
VarOperand,
VarOpResult,
Expand Down Expand Up @@ -2637,6 +2639,73 @@ def test_renamed_optional_prop(program: str, output: str, generic: str):
assert generic == stream.getvalue()


@pytest.mark.parametrize(
"program, generic",
[
(
"test.opt_constant : ()",
'"test.opt_constant"() : () -> ()',
),
(
"%0 = test.opt_constant value 1 : i32 : (i32)",
'%0 = "test.opt_constant"() <{"value" = 1 : i32}> : () -> (i32)',
),
],
)
def test_optional_property_with_extractor(program: str, generic: str):
@irdl_op_definition
class OptConstantOp(IRDLOperation):
name = "test.opt_constant"
T: ClassVar = VarConstraint("T", AnyAttr())

value = opt_prop_def(TypedAttributeConstraint(AnyIntegerAttrConstr, T))

res = opt_result_def(T)

assembly_format = "(`value` $value^)? attr-dict `:` `(` type($res) `)`"

ctx = MLContext()
ctx.load_op(OptConstantOp)

check_roundtrip(program, ctx)
check_equivalence(program, generic, ctx)


@pytest.mark.parametrize(
"program, generic",
[
(
"%0 = test.default_constant",
'%0 = "test.default_constant"() <{"value" = true}> : () -> (i1)',
),
(
"%0 = test.default_constant value 2 : i32",
'%0 = "test.default_constant"() <{"value" = 2 : i32}> : () -> (i32)',
),
],
)
def test_default_property_with_extractor(program: str, generic: str):
@irdl_op_definition
class DefaultConstantOp(IRDLOperation):
name = "test.default_constant"
T: ClassVar = VarConstraint("T", AnyAttr())

value = prop_def(
TypedAttributeConstraint(AnyIntegerAttrConstr, T),
default_value=BoolAttr.from_bool(True),
)

res = result_def(T)

assembly_format = "(`value` $value^)? attr-dict"

ctx = MLContext()
ctx.load_op(DefaultConstantOp)

check_roundtrip(program, ctx)
check_equivalence(program, generic, ctx)


################################################################################
# Extractors #
################################################################################
Expand Down
21 changes: 17 additions & 4 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,15 +247,24 @@ def extract_var(self, a: ParsingState) -> ConstraintVariableType:

@dataclass(frozen=True)
class _AttrExtractor(VarExtractor[ParsingState]):
"""
Extracts constraint variables from the attributes/properties of an operation.
If the default_value field is None, then the attribute/property must always be
present (which is only the case for non-optional attributes/properties with no
default value).
"""

name: str
is_prop: bool
inner: VarExtractor[Attribute]
default_value: Attribute | None

def extract_var(self, a: ParsingState) -> ConstraintVariableType:
if self.is_prop:
attr = a.properties[self.name]
attr = a.properties.get(self.name, self.default_value)
else:
attr = a.attributes[self.name]
attr = a.attributes.get(self.name, self.default_value)
assert attr is not None
return self.inner.extract_var(attr)

def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]:
Expand All @@ -280,16 +289,20 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]:
}
)
for prop_name, prop_def in self.op_def.properties.items():
if isinstance(prop_def, OptionalDef) and prop_def.default_value is None:
continue
extractor_dicts.append(
{
v: self._AttrExtractor(prop_name, True, r)
v: self._AttrExtractor(prop_name, True, r, prop_def.default_value)
for v, r in prop_def.constr.get_variable_extractors().items()
}
)
for attr_name, attr_def in self.op_def.attributes.items():
if isinstance(attr_def, OptionalDef) and attr_def.default_value is None:
continue
extractor_dicts.append(
{
v: self._AttrExtractor(attr_name, False, r)
v: self._AttrExtractor(attr_name, False, r, attr_def.default_value)
for v, r in attr_def.constr.get_variable_extractors().items()
}
)
Expand Down

0 comments on commit d365db8

Please sign in to comment.