diff --git a/guppy/custom.py b/guppy/custom.py index 8c5a4fc1..f228f46f 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -239,7 +239,9 @@ def __init__(self, op: ops.OpType) -> None: self.op = op def compile(self, args: list[OutPortV]) -> list[OutPortV]: - node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node) + node = self.graph.add_node( + self.op.model_copy(), inputs=args, parent=self.dfg.node + ) return_ty = get_type(self.node) return [node.add_out_port(ty) for ty in type_to_row(return_ty)] diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 02b393e3..ce6d3780 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -126,7 +126,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: self.inputs = inputs pred = outputs[0] - assert isinstance(pred, tys.Sum) + assert isinstance(pred, tys.UnitSum | tys.GeneralSum) if isinstance(pred, tys.UnitSum): self.tuple_sum_rows = [[] for _ in range(pred.size)] else: @@ -555,7 +555,7 @@ class TypeApplication(BaseModel): # -------------------------------------- -class OpDef(BaseOp, allow_population_by_field_name=True): +class OpDef(BaseOp, populate_by_name=True): """Serializable definition for dynamically loaded operations.""" name: str # Unique identifier of the operation. @@ -577,4 +577,4 @@ class OpDef(BaseOp, allow_population_by_field_name=True): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/guppy/hugr/raw.py b/guppy/hugr/raw.py index a3249a3c..bedee643 100644 --- a/guppy/hugr/raw.py +++ b/guppy/hugr/raw.py @@ -15,7 +15,7 @@ class RawHugr(BaseModel): edges: list[Edge] def packb(self) -> bytes: - return ormsgpack.packb(self.dict(), option=ormsgpack.OPT_NON_STR_KEYS) + return ormsgpack.packb(self.model_dump(), option=ormsgpack.OPT_NON_STR_KEYS) @classmethod def unpackb(cls, b: bytes) -> "RawHugr": diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index 9098e552..9ccfbd93 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -1,6 +1,5 @@ import inspect import sys -from abc import ABC from enum import Enum from typing import Annotated, Literal @@ -24,7 +23,7 @@ class TypeParam(BaseModel): class BoundedNatParam(BaseModel): tp: Literal["BoundedNat"] = "BoundedNat" - bound: int | None + bound: int | None = None class OpaqueParam(BaseModel): @@ -118,26 +117,25 @@ class Tuple(BaseModel): inner: "TypeRow" -class Sum(ABC, BaseModel): - """Sum type, variants are tagged by their position in the type row""" +class UnitSum(BaseModel): + """Simple predicate where all variants are empty tuples""" t: Literal["Sum"] = "Sum" - -class UnitSum(Sum): - """Simple predicate where all variants are empty tuples""" - s: Literal["Unit"] = "Unit" size: int -class GeneralSum(Sum): +class GeneralSum(BaseModel): """General sum type that explicitly stores the types of the variants""" + t: Literal["Sum"] = "Sum" + s: Literal["General"] = "General" row: "TypeRow" +Sum = Annotated[UnitSum | GeneralSum, Field(discriminator="s")] # ---------------------------------------------- # --------------- ClassicType ------------------ # ---------------------------------------------- @@ -284,4 +282,4 @@ class Signature(BaseModel): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/guppy/hugr/val.py b/guppy/hugr/val.py index b8b5eabb..bb6a2cd2 100644 --- a/guppy/hugr/val.py +++ b/guppy/hugr/val.py @@ -18,7 +18,7 @@ class FunctionVal(BaseModel): """A higher-order function value.""" v: Literal["Function"] = "Function" - hugr: Any # TODO + hugr: Any = None # TODO class Tuple(BaseModel): @@ -50,4 +50,4 @@ class Sum(BaseModel): ) for _, c in classes: if issubclass(c, BaseModel): - c.update_forward_refs() + c.model_rebuild() diff --git a/pyproject.toml b/pyproject.toml index 0d4282b2..ed785b64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "guppy" authors = [{ name = "Mark Koch", email = "mark.koch@quantinuum.com" }] version = "0.1.0" requires-python = ">=3.9" -dependencies = ["graphviz", "networkx", "ormsgpack", "pydantic==1.10.8"] +dependencies = ["graphviz", "networkx", "ormsgpack", "pydantic==2.5.3"] readme = "README.md" license = { text = "Apache-2.0" }