diff --git a/guppy/custom.py b/guppy/custom.py index 8c5a4fc1..f228f46f 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -239,7 +239,9 @@ def __init__(self, op: ops.OpType) -> None: self.op = op def compile(self, args: list[OutPortV]) -> list[OutPortV]: - node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) + node = self.graph.add_node( + self.op.model_copy(), inputs=args, parent=self.dfg.node + ) return_ty = get_type(self.node) return [node.add_out_port(ty) for ty in type_to_row(return_ty)] diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 02b393e3..f0391799 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -13,6 +13,7 @@ FunctionType, PolyFuncType, SimpleType, + SumUnion, TypeRow, ) from .val import Value @@ -126,7 +127,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: self.inputs = inputs pred = outputs[0] - assert isinstance(pred, tys.Sum) + assert isinstance(pred, SumUnion) if isinstance(pred, tys.UnitSum): self.tuple_sum_rows = [[] for _ in range(pred.size)] else: @@ -555,7 +556,7 @@ class TypeApplication(BaseModel): # -------------------------------------- -class OpDef(BaseOp, allow_population_by_field_name=True): +class OpDef(BaseOp, populate_by_name=True): """Serializable definition for dynamically loaded operations.""" name: str # Unique identifier of the operation. @@ -577,4 +578,4 @@ class OpDef(BaseOp, allow_population_by_field_name=True): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/guppy/hugr/raw.py b/guppy/hugr/raw.py index a3249a3c..bedee643 100644 --- a/guppy/hugr/raw.py +++ b/guppy/hugr/raw.py @@ -15,7 +15,7 @@ class RawHugr(BaseModel): edges: list[Edge] def packb(self) -> bytes: - return ormsgpack.packb(self.dict(), option=ormsgpack.OPT_NON_STR_KEYS) + return ormsgpack.packb(self.model_dump(), option=ormsgpack.OPT_NON_STR_KEYS) @classmethod def unpackb(cls, b: bytes) -> "RawHugr": diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index 9098e552..750029f1 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -1,6 +1,5 @@ import inspect import sys -from abc import ABC from enum import Enum from typing import Annotated, Literal @@ -24,7 +23,7 @@ class TypeParam(BaseModel): class BoundedNatParam(BaseModel): tp: Literal["BoundedNat"] = "BoundedNat" - bound: int | None + bound: int | None = None class OpaqueParam(BaseModel): @@ -118,26 +117,32 @@ class Tuple(BaseModel): inner: "TypeRow" -class Sum(ABC, BaseModel): - """Sum type, variants are tagged by their position in the type row""" +# class Sum(ABC, BaseModel): +# """Sum type, variants are tagged by their position in the type row""" - t: Literal["Sum"] = "Sum" +# t: Literal["Sum"] = "Sum" -class UnitSum(Sum): +class UnitSum(BaseModel): """Simple predicate where all variants are empty tuples""" + t: Literal["Sum"] = "Sum" + s: Literal["Unit"] = "Unit" size: int -class GeneralSum(Sum): +class GeneralSum(BaseModel): """General sum type that explicitly stores the types of the variants""" + t: Literal["Sum"] = "Sum" + s: Literal["General"] = "General" row: "TypeRow" +SumUnion = UnitSum | GeneralSum +Sum = Annotated[SumUnion, Field(discriminator="s")] # ---------------------------------------------- # --------------- ClassicType ------------------ # ---------------------------------------------- @@ -284,4 +289,4 @@ class Signature(BaseModel): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/guppy/hugr/val.py b/guppy/hugr/val.py index b8b5eabb..bb6a2cd2 100644 --- a/guppy/hugr/val.py +++ b/guppy/hugr/val.py @@ -18,7 +18,7 @@ class FunctionVal(BaseModel): """A higher-order function value.""" v: Literal["Function"] = "Function" - hugr: Any # TODO + hugr: Any = None # TODO class Tuple(BaseModel): @@ -50,4 +50,4 @@ class Sum(BaseModel): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/pyproject.toml b/pyproject.toml index 6d7333c8..88c8d8d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "graphviz", "networkx", "ormsgpack", - "pydantic==1.10.8", + "pydantic==2.5.3", ] [project.optional-dependencies]