Skip to content

Commit

Permalink
chore: upgrade to pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jan 10, 2024
1 parent 137941f commit 12ee66c
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
4 changes: 3 additions & 1 deletion guppy/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ def __init__(self, op: ops.OpType) -> None:
self.op = op

def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(self.op.copy(), inputs=args, parent=self.dfg.node)
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]

Expand Down
6 changes: 3 additions & 3 deletions guppy/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0]
assert isinstance(pred, tys.Sum)
assert isinstance(pred, tys.UnitSum | tys.GeneralSum)
if isinstance(pred, tys.UnitSum):
self.tuple_sum_rows = [[] for _ in range(pred.size)]
else:
Expand Down Expand Up @@ -555,7 +555,7 @@ class TypeApplication(BaseModel):
# --------------------------------------


class OpDef(BaseOp, allow_population_by_field_name=True):
class OpDef(BaseOp, populate_by_name=True):
"""Serializable definition for dynamically loaded operations."""

name: str # Unique identifier of the operation.
Expand All @@ -577,4 +577,4 @@ class OpDef(BaseOp, allow_population_by_field_name=True):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
2 changes: 1 addition & 1 deletion guppy/hugr/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RawHugr(BaseModel):
edges: list[Edge]

def packb(self) -> bytes:
return ormsgpack.packb(self.dict(), option=ormsgpack.OPT_NON_STR_KEYS)
return ormsgpack.packb(self.model_dump(), option=ormsgpack.OPT_NON_STR_KEYS)

@classmethod
def unpackb(cls, b: bytes) -> "RawHugr":
Expand Down
18 changes: 8 additions & 10 deletions guppy/hugr/tys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
import sys
from abc import ABC
from enum import Enum
from typing import Annotated, Literal

Expand All @@ -24,7 +23,7 @@ class TypeParam(BaseModel):

class BoundedNatParam(BaseModel):
tp: Literal["BoundedNat"] = "BoundedNat"
bound: int | None
bound: int | None = None


class OpaqueParam(BaseModel):
Expand Down Expand Up @@ -118,26 +117,25 @@ class Tuple(BaseModel):
inner: "TypeRow"


class Sum(ABC, BaseModel):
"""Sum type, variants are tagged by their position in the type row"""
class UnitSum(BaseModel):
"""Simple predicate where all variants are empty tuples"""

t: Literal["Sum"] = "Sum"


class UnitSum(Sum):
"""Simple predicate where all variants are empty tuples"""

s: Literal["Unit"] = "Unit"
size: int


class GeneralSum(Sum):
class GeneralSum(BaseModel):
"""General sum type that explicitly stores the types of the variants"""

t: Literal["Sum"] = "Sum"

s: Literal["General"] = "General"
row: "TypeRow"


Sum = Annotated[UnitSum | GeneralSum, Field(discriminator="s")]
# ----------------------------------------------
# --------------- ClassicType ------------------
# ----------------------------------------------
Expand Down Expand Up @@ -284,4 +282,4 @@ class Signature(BaseModel):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
4 changes: 2 additions & 2 deletions guppy/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class FunctionVal(BaseModel):
"""A higher-order function value."""

v: Literal["Function"] = "Function"
hugr: Any # TODO
hugr: Any = None # TODO


class Tuple(BaseModel):
Expand Down Expand Up @@ -50,4 +50,4 @@ class Sum(BaseModel):
)
for _, c in classes:
if issubclass(c, BaseModel):
c.update_forward_refs()
c.model_rebuild()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "guppy"
authors = [{ name = "Mark Koch", email = "[email protected]" }]
version = "0.1.0"
requires-python = ">=3.9"
dependencies = ["graphviz", "networkx", "ormsgpack", "pydantic==1.10.8"]
dependencies = ["graphviz", "networkx", "ormsgpack", "pydantic==2.5.3"]
readme = "README.md"
license = { text = "Apache-2.0" }

Expand Down

0 comments on commit 12ee66c

Please sign in to comment.