Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (builtin) add CompileTimeFixedBitwidthType #3599

Merged
merged 7 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Float80Type,
Float128Type,
FloatAttr,
IndexType,
IntAttr,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -78,6 +79,10 @@ def test_IntegerType_formats():
assert IntegerType(64).format == "<q"


def test_IndexType_formats():
assert IndexType().format == "<q"


def test_FloatType_packing():
nums = (-128, -1, 0, 1, 127)
buffer = f32.pack(nums)
Expand Down
35 changes: 30 additions & 5 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,23 @@ def print_parameter(self, printer: Printer) -> None:
raise ValueError(f"Invalid signedness {data}")


class FixedBitwidthType(TypeAttribute, ABC):
class CompileTimeFixedBitwidthType(TypeAttribute, ABC):
"""
A type attribute whose runtime bitwidth is fixed, but may be target-dependent.
"""

name = "abstract.compile_time_fixed_bitwidth_type"

@property
@abstractmethod
def compile_time_size(self) -> int:
"""
Contiguous memory footprint of the value during compilation.
"""
raise NotImplementedError()


class FixedBitwidthType(CompileTimeFixedBitwidthType, ABC):
"""
A type attribute whose runtime bitwidth is target-independent.
"""
Expand All @@ -375,7 +391,7 @@ def size(self) -> int:
_PyT = TypeVar("_PyT")


class PackableType(Generic[_PyT], FixedBitwidthType, ABC):
class PackableType(Generic[_PyT], CompileTimeFixedBitwidthType, ABC):
"""
Abstract base class for xDSL types whose values can be encoded and decoded as bytes.
"""
Expand Down Expand Up @@ -439,9 +455,13 @@ def pack(self, values: Sequence[_PyT]) -> bytes:
fmt = self.format[0] + str(len(values)) + self.format[1:]
return struct.pack(fmt, *values)

@property
def compile_time_size(self) -> int:
return struct.calcsize(self.format)


@irdl_attr_definition
class IntegerType(ParametrizedAttribute, StructPackableType[int]):
class IntegerType(ParametrizedAttribute, StructPackableType[int], FixedBitwidthType):
name = "integer_type"
width: ParameterDef[IntAttr]
signedness: ParameterDef[SignednessAttr]
Expand Down Expand Up @@ -556,7 +576,7 @@ class LocationAttr(ParametrizedAttribute):


@irdl_attr_definition
class IndexType(ParametrizedAttribute):
class IndexType(ParametrizedAttribute, StructPackableType[int]):
name = "index"

def print_value_without_type(self, value: int, printer: Printer):
Expand All @@ -565,6 +585,11 @@ def print_value_without_type(self, value: int, printer: Printer):
"""
printer.print_string(f"{value}")

@property
def format(self) -> str:
# index types are always packable as int64
return "<q"


IndexTypeConstr = BaseAttr(IndexType)

Expand Down Expand Up @@ -669,7 +694,7 @@ def constr(
BoolAttr: TypeAlias = IntegerAttr[Annotated[IntegerType, IntegerType(1)]]


class _FloatType(StructPackableType[float], ABC):
class _FloatType(StructPackableType[float], FixedBitwidthType, ABC):
@property
@abstractmethod
def bitwidth(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions xdsl/interpreters/utils/ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TypedPtr(Generic[_T]):

@property
def size(self) -> int:
return self.xtype.size
return self.xtype.compile_time_size

def copy(self) -> Self:
return type(self)(self.raw.copy(), xtype=self.xtype)
Expand All @@ -122,15 +122,15 @@ def __setitem__(self, index: int, value: _T):

@staticmethod
def zeros(count: int, *, xtype: PackableType[_T]) -> TypedPtr[_T]:
size = xtype.size
size = xtype.compile_time_size
return TypedPtr(RawPtr.zeros(size * count), xtype=xtype)

@staticmethod
def new(els: Sequence[_T], *, xtype: PackableType[_T]) -> TypedPtr[_T]:
"""
Returns a new TypedPtr with the specified els packed into memory.
"""
el_size = xtype.size
el_size = xtype.compile_time_size
res = RawPtr.zeros(len(els) * el_size)
for i, el in enumerate(els):
xtype.pack_into(res.memory, i * el_size, el)
Expand Down
Loading