Skip to content

Commit

Permalink
core: make DenseIntOrFPElementsAttr generic on element type
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 22, 2024
1 parent bc5ea1e commit 93a17b5
Show file tree
Hide file tree
Showing 16 changed files with 120 additions and 98 deletions.
4 changes: 2 additions & 2 deletions docs/Toy/toy/dialects/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ class ConstantOp(IRDLOperation):
"""

name = "toy.constant"
value = attr_def(DenseIntOrFPElementsAttr)
value = attr_def(DenseIntOrFPElementsAttr[Float64Type])
res = result_def(TensorTypeF64)

traits = traits_def(Pure())

def __init__(self, value: DenseIntOrFPElementsAttr):
def __init__(self, value: DenseIntOrFPElementsAttr[Float64Type]):
super().__init__(result_types=[value.type], attributes={"value": value})

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ def attribute_value_to_str(self, attr: Attribute) -> str:
return str(val.data)
case StringAttr() as s:
return f'"{s.data}"'
case DenseIntOrFPElementsAttr(data=ArrayAttr(data=data), type=typ):
return f"{self.mlir_type_to_csl_type(typ)} {{ {', '.join(self.attribute_value_to_str(d) for d in data)} }}"
case DenseIntOrFPElementsAttr(data=ArrayAttr(data=data)):
return f"{self.mlir_type_to_csl_type(attr.get_type())} {{ {', '.join(self.attribute_value_to_str(d) for d in data)} }}"
case _:
return f"<!unknown value {attr}>"

Expand Down
6 changes: 3 additions & 3 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
DenseIntElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -224,9 +224,9 @@ class ParallelOp(IRDLOperation):

reductions = prop_def(ArrayAttr[StringAttr])
lowerBoundsMap = prop_def(AffineMapAttr)
lowerBoundsGroups = prop_def(DenseIntOrFPElementsAttr)
lowerBoundsGroups = prop_def(DenseIntElementsAttr)
upperBoundsMap = prop_def(AffineMapAttr)
upperBoundsGroups = prop_def(DenseIntOrFPElementsAttr)
upperBoundsGroups = prop_def(DenseIntElementsAttr)
steps = prop_def(ArrayAttr[IntegerAttr[IntegerType]])

res = var_result_def()
Expand Down
7 changes: 5 additions & 2 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import ClassVar, Literal, TypeVar, cast, overload

from xdsl.dialects.builtin import (
AnyDenseElement,
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
Expand Down Expand Up @@ -132,7 +133,9 @@ class Constant(IRDLOperation):
@overload
def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value: AnyIntegerAttr
| FloatAttr[AnyFloat]
| DenseIntOrFPElementsAttr[AnyDenseElement],
value_type: None = None,
) -> None: ...

Expand Down Expand Up @@ -179,7 +182,7 @@ def parse(cls: type[Constant], parser: Parser) -> Constant:
value,
base(AnyIntegerAttr)
| base(FloatAttr[AnyFloat])
| base(DenseIntOrFPElementsAttr),
| base(DenseIntOrFPElementsAttr[AnyDenseElement]),
):
parser.raise_error("Invalid constant value", p0, parser.pos)

Expand Down
101 changes: 44 additions & 57 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,16 +1717,19 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems:
)

AnyDenseElement: TypeAlias = IntegerType | IndexType | AnyFloat
DenseElementT = TypeVar("DenseElementT", bound=AnyDenseElement, covariant=True)
_DenseElementT = TypeVar("_DenseElementT", bound=AnyDenseElement)
FloatTypeT = TypeVar("FloatTypeT", bound=AnyFloat)


@irdl_attr_definition
class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]):
class DenseIntOrFPElementsAttr(
Generic[DenseElementT],
TypedAttribute,
ContainerType[DenseElementT],
):
name = "dense"
type: ParameterDef[
RankedStructure[IntegerType]
| RankedStructure[IndexType]
| RankedStructure[AnyFloat]
]
type: ParameterDef[RankedStructure[DenseElementT]]
data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]]

# The type stores the shape data
Expand All @@ -1735,7 +1738,7 @@ def get_shape(self) -> tuple[int, ...] | None:
return None
return self.type.get_shape()

def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
def get_element_type(self) -> DenseElementT:
return self.type.get_element_type()

@property
Expand All @@ -1758,21 +1761,21 @@ def shape_is_complete(self) -> bool:
def create_dense_index(
type: RankedStructure[IndexType],
data: Sequence[int] | Sequence[IntegerAttr[IndexType]],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[IndexType]:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr.from_index_int_value(d) for d in cast(Sequence[int], data)
]
else:
attr_list = cast(Sequence[IntegerAttr[IndexType]], data)

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])
return DenseIntOrFPElementsAttr[IndexType]([type, ArrayAttr(attr_list)])

@staticmethod
def create_dense_int(
type: RankedStructure[IntegerType],
data: Sequence[int] | Sequence[IntegerAttr[IntegerType]],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[IntegerType]:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr[IntegerType](d, type.element_type)
Expand All @@ -1785,9 +1788,9 @@ def create_dense_int(

@staticmethod
def create_dense_float(
type: RankedStructure[AnyFloat],
type: RankedStructure[FloatTypeT],
data: Sequence[int | float] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[FloatTypeT]:
if len(data) and isinstance(data[0], int | float):
attr_list = [
FloatAttr(float(d), type.element_type)
Expand All @@ -1798,64 +1801,40 @@ def create_dense_float(

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])

@overload
@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: (
Sequence[int]
| Sequence[IntegerAttr[IndexType]]
| Sequence[IntegerAttr[IntegerType]]
),
) -> DenseIntOrFPElementsAttr: ...

@overload
@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr: ...

@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
type: RankedStructure[_DenseElementT],
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
if isinstance(type.element_type, AnyFloat):
new_type = cast(RankedStructure[AnyFloat], type)
new_data = cast(Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data)
return DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data),
)
elif isinstance(type.element_type, IntegerType):
new_type = cast(RankedStructure[IntegerType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], data)
return DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data),
)
else:
new_type = cast(RankedStructure[IndexType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], data)
return DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data),
)

@staticmethod
def vector_from_list(
data: Sequence[int] | Sequence[float],
data_type: IntegerType | IndexType | AnyFloat,
) -> DenseIntOrFPElementsAttr:
data_type: _DenseElementT,
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
t = VectorType(data_type, [len(data)])
return DenseIntOrFPElementsAttr.from_list(t, data)
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data)

@staticmethod
def tensor_from_list(
Expand All @@ -1866,15 +1845,20 @@ def tensor_from_list(
| Sequence[IntegerAttr[IntegerType]]
| Sequence[AnyFloatAttr]
),
data_type: IntegerType | IndexType | AnyFloat,
data_type: _DenseElementT,
shape: Sequence[int],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data)

@staticmethod
def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute:
assert isa(type, RankedStructure[AnyDenseElement])
assert (
isa(type, VectorType[AnyDenseElement])
or isa(type, TensorType[AnyDenseElement])
or isa(type, MemRefType[AnyDenseElement])
)

return parser.parse_dense_int_or_fp_elements_attr(type)

@staticmethod
Expand Down Expand Up @@ -1926,6 +1910,9 @@ def print_without_type(self, printer: Printer):
printer.print_string(">")


DenseIntElementsAttr: TypeAlias = DenseIntOrFPElementsAttr[IntegerType]


Builtin = Dialect(
"builtin",
[
Expand Down
13 changes: 6 additions & 7 deletions xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from xdsl.dialects.builtin import (
DenseArrayBase,
DenseIntOrFPElementsAttr,
IndexType,
DenseIntElementsAttr,
IndexTypeConstr,
IntegerType,
SignlessIntegerConstraint,
Expand Down Expand Up @@ -178,7 +177,7 @@ class Switch(IRDLOperation):

name = "cf.switch"

case_values = opt_prop_def(DenseIntOrFPElementsAttr)
case_values = opt_prop_def(DenseIntElementsAttr)

flag = operand_def(IndexTypeConstr | SignlessIntegerConstraint)

Expand All @@ -202,7 +201,7 @@ def __init__(
flag: Operation | SSAValue,
default_block: Successor,
default_operands: Sequence[Operation | SSAValue],
case_values: DenseIntOrFPElementsAttr | None = None,
case_values: DenseIntElementsAttr | None = None,
case_blocks: Sequence[Successor] = [],
case_operands: Sequence[Sequence[Operation | SSAValue]] = [],
attr_dict: dict[str, Attribute] | None = None,
Expand Down Expand Up @@ -355,15 +354,15 @@ def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("[")
parser.parse_keyword("default")
(default_block, default_args) = cls._parse_case_body(parser)
case_values: DenseIntOrFPElementsAttr | None = None
case_values: DenseIntElementsAttr | None = None
case_blocks: tuple[Block, ...] = ()
case_operands: tuple[tuple[SSAValue, ...], ...] = ()
if parser.parse_optional_punctuation(","):
cases = parser.parse_comma_separated_list(
Parser.Delimiter.NONE, lambda: cls._parse_case(parser)
)
assert isinstance(flag_type, IntegerType | IndexType)
case_values = DenseIntOrFPElementsAttr.vector_from_list(
assert isinstance(flag_type, IntegerType)
case_values = DenseIntElementsAttr.vector_from_list(
[x for (x, _, _) in cases], flag_type
)
case_blocks = tuple(x for (_, x, _) in cases)
Expand Down
10 changes: 5 additions & 5 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
AnyTensorType,
ArrayAttr,
DenseArrayBase,
DenseIntOrFPElementsAttr,
DenseIntElementsAttr,
IntegerType,
MemRefType,
ShapedType,
Expand Down Expand Up @@ -961,8 +961,8 @@ class PoolingOpsBase(IRDLOperation, ABC):
"`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
)

strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
strides = attr_def(DenseIntElementsAttr)
dilations = attr_def(DenseIntElementsAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

Expand Down Expand Up @@ -1012,8 +1012,8 @@ class ConvOpsBase(IRDLOperation, ABC):
"`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
)

strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
strides = attr_def(DenseIntElementsAttr)
dilations = attr_def(DenseIntElementsAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

Expand Down
6 changes: 4 additions & 2 deletions xdsl/dialects/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from xdsl.dialects.builtin import (
Any,
AnyDenseElement,
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
Expand Down Expand Up @@ -44,6 +45,7 @@
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.isattr import isattr


def verify_unidirectional_broadcast_shape(
Expand Down Expand Up @@ -591,7 +593,7 @@ class Constant(IRDLOperation):
name = "onnx.Constant"
output = result_def(AnyTensorType)

value = opt_attr_def(DenseIntOrFPElementsAttr)
value = opt_attr_def(DenseIntOrFPElementsAttr[AnyDenseElement])
value_float = opt_attr_def(FloatAttr[Float32Type])
value_floats = opt_attr_def(ArrayAttr[FloatAttr[Float32Type]])
value_int = opt_attr_def(IntegerAttr[IntegerType])
Expand Down Expand Up @@ -664,7 +666,7 @@ def print(self, printer: Printer):
@classmethod
def parse(cls, parser: Parser) -> Self:
v = parser.parse_attribute()
if not isinstance(v, DenseIntOrFPElementsAttr):
if not isattr(v, base(DenseIntOrFPElementsAttr[AnyDenseElement])):
raise NotImplementedError()
constant = cls(v, None, None, None, None, None, None, v.type)
return constant
Expand Down
Loading

0 comments on commit 93a17b5

Please sign in to comment.