From 1aa5619b9bfcc89844763affead1790061f2087e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Fri, 20 Dec 2024 13:11:33 +0100 Subject: [PATCH] fix(py): Fix array/list value serialization --- hugr-py/src/hugr/std/collections/array.py | 8 ++++++-- hugr-py/src/hugr/std/collections/list.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/hugr-py/src/hugr/std/collections/array.py b/hugr-py/src/hugr/std/collections/array.py index f7638e4f7..3e7f3bcfc 100644 --- a/hugr-py/src/hugr/std/collections/array.py +++ b/hugr-py/src/hugr/std/collections/array.py @@ -60,7 +60,7 @@ class ArrayVal(val.ExtensionValue): """Constant value for a statically sized array of elements.""" v: list[val.Value] - ty: tys.Type + ty: Array def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: self.v = v @@ -71,7 +71,11 @@ def to_value(self) -> val.Extension: # The value list must be serialized at this point, otherwise the # `Extension` value would not be serializable. vs = [v._to_serial_root() for v in self.v] - return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name]) + element_ty = self.ty.ty._to_serial_root() + serial_val = {"values": vs, "typ": element_ty} + return val.Extension( + name, typ=self.ty, val=serial_val, extensions=[EXTENSION.name] + ) def __str__(self) -> str: return f"array({comma_sep_str(self.v)})" diff --git a/hugr-py/src/hugr/std/collections/list.py b/hugr-py/src/hugr/std/collections/list.py index dccf5ae44..e091bc365 100644 --- a/hugr-py/src/hugr/std/collections/list.py +++ b/hugr-py/src/hugr/std/collections/list.py @@ -39,7 +39,7 @@ class ListVal(val.ExtensionValue): """Constant value for a list of elements.""" v: list[val.Value] - ty: tys.Type + ty: List def __init__(self, v: list[val.Value], elem_ty: tys.Type) -> None: self.v = v @@ -50,7 +50,11 @@ def to_value(self) -> val.Extension: # The value list must be serialized at this point, otherwise the # `Extension` value would not be serializable. vs = [v._to_serial_root() for v in self.v] - return val.Extension(name, typ=self.ty, val=vs, extensions=[EXTENSION.name]) + element_ty = self.ty.ty._to_serial_root() + serial_val = {"values": vs, "typ": element_ty} + return val.Extension( + name, typ=self.ty, val=serial_val, extensions=[EXTENSION.name] + ) def __str__(self) -> str: return f"[{comma_sep_str(self.v)}]"