From 70e0a64cf3aaa8d8be8c999684a6c173d7181663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 3 Sep 2024 15:18:20 +0100 Subject: [PATCH] feat: Export the collections extension (#1506) #1450 embedded the standard extension definitions in hugr-py, including `collections`, but it didn't add a way to load it as it did with all the others. This PR adds a `hugr.std.collections` module that just loads the bundled json. drive-by: Implement `__str__` for `FloatVal` and `IntVal` --- hugr-py/src/hugr/std/collections.py | 40 +++++++++++++++++++++++++++++ hugr-py/src/hugr/std/float.py | 3 +++ hugr-py/src/hugr/std/int.py | 5 ++++ 3 files changed, 48 insertions(+) create mode 100644 hugr-py/src/hugr/std/collections.py diff --git a/hugr-py/src/hugr/std/collections.py b/hugr-py/src/hugr/std/collections.py new file mode 100644 index 000000000..ba820845c --- /dev/null +++ b/hugr-py/src/hugr/std/collections.py @@ -0,0 +1,40 @@ +"""Collection types and operations.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import hugr.tys as tys +from hugr import val +from hugr.std import _load_extension +from hugr.utils import comma_sep_str + +EXTENSION = _load_extension("collections") + + +def list_type(ty: tys.Type) -> tys.ExtType: + """Returns a list type with a fixed element type.""" + arg = tys.TypeTypeArg(ty) + return EXTENSION.types["List"].instantiate([arg]) + + +@dataclass +class ListVal(val.ExtensionValue): + """Constant value for a list of elements.""" + + v: list[val.Value] + ty: tys.Type + + def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: + self.v = v + self.ty = list_type(elem_ty) + + def to_value(self) -> val.Extension: + name = "ListValue" + # The value list must be serialized at this point, otherwise the + # `Extension` value would not be serializable. + vs = [v._to_serial_root() for v in self.v] + return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name]) + + def __str__(self) -> str: + return f"[{comma_sep_str(self.v)}]" diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py index fa50a309e..a8eeac5c1 100644 --- a/hugr-py/src/hugr/std/float.py +++ b/hugr-py/src/hugr/std/float.py @@ -24,3 +24,6 @@ def to_value(self) -> val.Extension: return val.Extension( name, typ=FLOAT_T, val=payload, extensions=[EXTENSION.name] ) + + def __str__(self) -> str: + return f"{self.v}" diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 699f6bd45..3f437df92 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from hugr.ops import Command, ComWire +CONVERSIONS_EXTENSION = _load_extension("arithmetic.conversions") + INT_TYPES_EXTENSION = _load_extension("arithmetic.int.types") _INT_PARAM = tys.BoundedNatParam(7) @@ -66,6 +68,9 @@ def to_value(self) -> val.Extension: extensions=[INT_TYPES_EXTENSION.name], ) + def __str__(self) -> str: + return f"{self.v}" + INT_OPS_EXTENSION = _load_extension("arithmetic.int")