Skip to content

Commit

Permalink
API: add InferenceContext (#3500)
Browse files Browse the repository at this point in the history
A wrapper class to encapsulate the inference context. I have a feeling
that we might want to split out the inference context into range and
non-range variable assignments, this will hide the change from the
signature of the `infer` method.
  • Loading branch information
superlopuh authored Nov 22, 2024
1 parent 4e09b2d commit 150b84b
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 43 deletions.
5 changes: 3 additions & 2 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from xdsl.ir import Attribute
from xdsl.irdl import (
EqAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand All @@ -39,13 +40,13 @@ def test_tensor_from_memref_inference():
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer({}) == TensorType(f64, [10, 20, 30])
assert constr2.infer(InferenceContext()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer({}) == UnrankedTensorType(f64)
assert constr3.infer(InferenceContext()) == UnrankedTensorType(f64)


@irdl_op_definition
Expand Down
7 changes: 4 additions & 3 deletions tests/irdl/test_attr_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AttrConstraint,
BaseAttr,
EqAttrConstraint,
InferenceContext,
ParamAttrConstraint,
ParameterDef,
VarConstraint,
Expand Down Expand Up @@ -77,7 +78,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert constr.can_infer(set())
assert constr.infer({}) == WrapAttr((StringAttr("Hello"),))
assert constr.infer(InferenceContext()) == WrapAttr((StringAttr("Hello"),))

var_constr = ParamAttrConstraint(
WrapAttr,
Expand All @@ -92,7 +93,7 @@ class WrapAttr(BaseWrapAttr): ...
)

assert var_constr.can_infer({"T"})
assert var_constr.infer({"T": StringAttr("Hello")}) == WrapAttr(
assert var_constr.infer(InferenceContext({"T": StringAttr("Hello")})) == WrapAttr(
(StringAttr("Hello"),)
)

Expand Down Expand Up @@ -127,7 +128,7 @@ class NoParamAttr(BaseNoParamAttr): ...
constr = BaseAttr(NoParamAttr)

assert constr.can_infer(set())
assert constr.infer({}) == NoParamAttr()
assert constr.infer(InferenceContext()) == NoParamAttr()

base_constr = BaseAttr(BaseNoParamAttr)
assert not base_constr.can_infer(set())
Expand Down
6 changes: 3 additions & 3 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from xdsl.irdl import (
AttrSizedOperandSegments,
ConstraintContext,
ConstraintVariableType,
GenericAttrConstraint,
InferenceContext,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand Down Expand Up @@ -53,9 +53,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.memref_constraint.can_infer(var_constraint_names)

def infer(
self, variables: dict[str, ConstraintVariableType]
self, context: InferenceContext
) -> TensorType[Attribute] | UnrankedTensorType[Attribute]:
memref_type = self.memref_constraint.infer(variables)
memref_type = self.memref_constraint.infer(context)
if isinstance(memref_type, MemRefType):
return TensorType(memref_type.element_type, memref_type.shape)
return UnrankedTensorType(memref_type.element_type)
Expand Down
59 changes: 28 additions & 31 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def update(self, other: ConstraintContext):
Possible types that a constraint variable can have.
"""


@dataclass
class InferenceContext:
variables: dict[str, ConstraintVariableType] = field(default_factory=dict)
"""
A mapping from variable names to the inferred attribute or attribute sequence.
"""


_T = TypeVar("_T")


Expand Down Expand Up @@ -156,7 +165,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
# By default, we cannot infer anything.
return False

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
def infer(self, context: InferenceContext) -> AttributeCovT:
"""
Infer the attribute given the the values for all variables.
Expand Down Expand Up @@ -228,8 +237,8 @@ def verify(
def get_variable_extractors(self) -> dict[str, VarExtractor[AttributeCovT]]:
return {self.name: IdExtractor()}

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
v = variables[self.name]
def infer(self, context: InferenceContext) -> AttributeCovT:
v = context.variables[self.name]
return cast(AttributeCovT, v)

def can_infer(self, var_constraint_names: Set[str]) -> bool:
Expand Down Expand Up @@ -272,7 +281,7 @@ def verify(
def can_infer(self, var_constraint_names: Set[str]) -> bool:
return True

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
def infer(self, context: InferenceContext) -> AttributeCovT:
return self.attr

def get_unique_base(self) -> type[Attribute] | None:
Expand Down Expand Up @@ -303,7 +312,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
and not self.attr.get_irdl_definition().parameters
)

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
def infer(self, context: InferenceContext) -> AttributeCovT:
assert issubclass(self.attr, ParametrizedAttribute)
attr = self.attr.new(())
return attr
Expand Down Expand Up @@ -439,10 +448,10 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
constr.can_infer(var_constraint_names) for constr in self.attr_constrs
)

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
def infer(self, context: InferenceContext) -> AttributeCovT:
for constr in self.attr_constrs:
if constr.can_infer(variables.keys()):
return constr.infer(variables)
if constr.can_infer(context.variables.keys()):
return constr.infer(context)
raise ValueError("Cannot infer attribute from constraint")

def get_unique_base(self) -> type[Attribute] | None:
Expand Down Expand Up @@ -535,10 +544,8 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
constr.can_infer(var_constraint_names) for constr in self.param_constrs
)

def infer(
self, variables: dict[str, ConstraintVariableType]
) -> ParametrizedAttributeCovT:
params = tuple(constr.infer(variables) for constr in self.param_constrs)
def infer(self, context: InferenceContext) -> ParametrizedAttributeCovT:
params = tuple(constr.infer(context) for constr in self.param_constrs)
attr = self.base_attr.new(params)
return attr

Expand Down Expand Up @@ -590,8 +597,8 @@ def get_unique_base(self) -> type[Attribute] | None:
def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.constr.can_infer(var_constraint_names)

def infer(self, variables: dict[str, ConstraintVariableType]) -> AttributeCovT:
return self.constr.infer(variables)
def infer(self, context: InferenceContext) -> AttributeCovT:
return self.constr.infer(context)


@dataclass(frozen=True)
Expand Down Expand Up @@ -628,9 +635,7 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
# By default, we cannot infer anything.
return False

def infer(
self, length: int, variables: dict[str, ConstraintVariableType]
) -> Sequence[AttributeCovT]:
def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
"""
Infer the attribute given the the values for all variables.
Expand Down Expand Up @@ -684,12 +689,8 @@ def get_variable_extractors(
def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.name in var_constraint_names

def infer(
self,
length: int,
variables: dict[str, ConstraintVariableType],
) -> Sequence[AttributeCovT]:
v = variables[self.name]
def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
v = context.variables[self.name]
return cast(Sequence[AttributeCovT], v)


Expand All @@ -715,9 +716,9 @@ def can_infer(self, var_constraint_names: Set[str]) -> bool:
def infer(
self,
length: int,
variables: dict[str, ConstraintVariableType],
context: InferenceContext,
) -> Sequence[AttributeCovT]:
attr = self.constr.infer(variables)
attr = self.constr.infer(context)
return (attr,) * length


Expand Down Expand Up @@ -756,12 +757,8 @@ def get_variable_extractors(
def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.constr.can_infer(var_constraint_names)

def infer(
self,
length: int,
variables: dict[str, ConstraintVariableType],
) -> Sequence[AttributeCovT]:
return (self.constr.infer(variables),)
def infer(self, length: int, context: InferenceContext) -> Sequence[AttributeCovT]:
return (self.constr.infer(context),)


def range_constr_coercion(
Expand Down
9 changes: 6 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from xdsl.irdl import (
ConstraintVariableType,
InferenceContext,
IRDLOperation,
IRDLOperationInvT,
OpDef,
Expand Down Expand Up @@ -186,7 +187,7 @@ def resolve_operand_types(self, state: ParsingState, op_def: OpDef) -> None:
range_length = len(operand) if isinstance(operand, Sequence) else 1
operand_type = operand_def.constr.infer(
range_length,
state.variables,
InferenceContext(state.variables),
)
resolved_operand_type: Attribute | Sequence[Attribute]
if isinstance(operand_def, OptionalDef):
Expand Down Expand Up @@ -220,7 +221,7 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None:
range_length = 1
inferred_result_types = result_def.constr.infer(
range_length,
state.variables,
InferenceContext(state.variables),
)
resolved_result_type = inferred_result_types[0]
state.result_types[i] = resolved_result_type
Expand Down Expand Up @@ -885,7 +886,9 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
):
attr = unique_base.new(unique_base.parse_parameters(parser))
elif issubclass(unique_base, Data):
attr = unique_base.new(unique_base.parse_parameter(parser)) # pyright: ignore[reportUnknownVariableType]
attr = unique_base.new( # pyright: ignore[reportUnknownVariableType]
unique_base.parse_parameter(parser)
)
else:
raise ValueError("Attributes must be Data or ParameterizedAttribute.")
if self.is_property:
Expand Down
3 changes: 2 additions & 1 deletion xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AttrSizedOperandSegments,
AttrSizedSegments,
ConstraintVariableType,
InferenceContext,
OpDef,
OptionalDef,
OptOperandDef,
Expand Down Expand Up @@ -528,7 +529,7 @@ def parse_optional_variable(
unique_base.get_type_index()
]
if type_constraint.can_infer(set()):
unique_type = type_constraint.infer({})
unique_type = type_constraint.infer(InferenceContext())
if (
unique_base is not None
and unique_base in Builtin.attributes
Expand Down

0 comments on commit 150b84b

Please sign in to comment.