Skip to content

Commit

Permalink
chore: upgrade to pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jan 10, 2024
1 parent d0cf104 commit d7e1a41
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
4 changes: 3 additions & 1 deletion guppy/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ def __init__(self, op: ops.OpType) -> None:
self.op = op

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node)
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]

Expand Down
7 changes: 4 additions & 3 deletions guppy/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
FunctionType,
PolyFuncType,
SimpleType,
SumUnion,
TypeRow,
)
from .val import Value
Expand Down Expand Up @@ -126,7 +127,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0]
assert isinstance(pred, tys.Sum)
assert isinstance(pred, SumUnion)
if isinstance(pred, tys.UnitSum):
self.tuple_sum_rows = [[] for _ in range(pred.size)]
else:
Expand Down Expand Up @@ -555,7 +556,7 @@ class TypeApplication(BaseModel):
# --------------------------------------


class OpDef(BaseOp, allow_population_by_field_name=True):
class OpDef(BaseOp, populate_by_name=True):
"""Serializable definition for dynamically loaded operations."""

name: str # Unique identifier of the operation.
Expand All @@ -577,4 +578,4 @@ class OpDef(BaseOp, allow_population_by_field_name=True):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
2 changes: 1 addition & 1 deletion guppy/hugr/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RawHugr(BaseModel):
edges: list[Edge]

def packb(self) -> bytes:
return ormsgpack.packb(self.dict(), option=ormsgpack.OPT_NON_STR_KEYS)
return ormsgpack.packb(self.model_dump(), option=ormsgpack.OPT_NON_STR_KEYS)

@classmethod
def unpackb(cls, b: bytes) -> "RawHugr":
Expand Down
21 changes: 13 additions & 8 deletions guppy/hugr/tys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import sys
from abc import ABC
from enum import Enum
from typing import Annotated, Literal

Expand All @@ -24,7 +23,7 @@ class TypeParam(BaseModel):

class BoundedNatParam(BaseModel):
tp: Literal["BoundedNat"] = "BoundedNat"
bound: int | None
bound: int | None = None


class OpaqueParam(BaseModel):
Expand Down Expand Up @@ -118,26 +117,32 @@ class Tuple(BaseModel):
inner: "TypeRow"


class Sum(ABC, BaseModel):
"""Sum type, variants are tagged by their position in the type row"""
# class Sum(ABC, BaseModel):
# """Sum type, variants are tagged by their position in the type row"""

t: Literal["Sum"] = "Sum"
# t: Literal["Sum"] = "Sum"


class UnitSum(Sum):
class UnitSum(BaseModel):
"""Simple predicate where all variants are empty tuples"""

t: Literal["Sum"] = "Sum"

s: Literal["Unit"] = "Unit"
size: int


class GeneralSum(Sum):
class GeneralSum(BaseModel):
"""General sum type that explicitly stores the types of the variants"""

t: Literal["Sum"] = "Sum"

s: Literal["General"] = "General"
row: "TypeRow"


SumUnion = UnitSum | GeneralSum
Sum = Annotated[SumUnion, Field(discriminator="s")]
# ----------------------------------------------
# --------------- ClassicType ------------------
# ----------------------------------------------
Expand Down Expand Up @@ -284,4 +289,4 @@ class Signature(BaseModel):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
4 changes: 2 additions & 2 deletions guppy/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class FunctionVal(BaseModel):
"""A higher-order function value."""

v: Literal["Function"] = "Function"
hugr: Any # TODO
hugr: Any = None # TODO


class Tuple(BaseModel):
Expand Down Expand Up @@ -50,4 +50,4 @@ class Sum(BaseModel):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"graphviz",
"networkx",
"ormsgpack",
"pydantic==1.10.8",
"pydantic==2.5.3",
]

[project.optional-dependencies]
Expand Down

0 comments on commit d7e1a41

Please sign in to comment.