Skip to content

Commit

Permalink
fix: Schema issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored and doug-q committed May 1, 2024
1 parent fb76c7e commit c92d312
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 118 deletions.
69 changes: 25 additions & 44 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from abc import ABC
from typing import Any, Literal, cast
from typing import Any, Literal

from pydantic import BaseModel, Field, RootModel

Expand All @@ -25,7 +25,7 @@ class BaseOp(ABC, BaseModel):

# Parent node index of node the op belongs to, used only at serialization time
parent: NodeID
input_extensions: ExtensionSet = Field(default_factory=ExtensionSet)
input_extensions: ExtensionSet | None = Field(default=None)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
"""Hook to insert type information from the input and output ports into the
Expand Down Expand Up @@ -56,29 +56,15 @@ class FuncDefn(BaseOp):
op: Literal["FuncDefn"] = "FuncDefn"

name: str
signature: PolyFuncType = Field(default_factory=PolyFuncType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
assert len(out_types) == 1
out = out_types[0]
assert isinstance(out, PolyFuncType)
self.signature = out # TODO: Extensions
signature: PolyFuncType


class FuncDecl(BaseOp):
"""External function declaration, linked at runtime."""

op: Literal["FuncDecl"] = "FuncDecl"
name: str
signature: PolyFuncType = Field(default_factory=PolyFuncType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
assert len(out_types) == 1
out = out_types[0]
assert isinstance(out, PolyFuncType)
self.signature = out
signature: PolyFuncType


CustomConst = Any # TODO
Expand Down Expand Up @@ -183,13 +169,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:

def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0]
assert isinstance(pred, tys.UnitSum | tys.GeneralSum)
if isinstance(pred, tys.UnitSum):
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
pred = outputs[0].root
assert isinstance(pred, tys.SumType)
if isinstance(pred.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.root.size)]
else:
self.sum_rows = []
for variant in pred.rows:
for variant in pred.root.rows:
self.sum_rows.append(variant)
self.other_outputs = outputs[1:]

Expand Down Expand Up @@ -263,15 +249,9 @@ class Call(DataflowOp):
"""

op: Literal["Call"] = "Call"
signature: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# The constE edge comes after the value inputs
fun_ty = in_types[-1]
assert isinstance(fun_ty, PolyFuncType)
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
self.signature = poly_func.body
func_sig: PolyFuncType
type_args: list[tys.TypeArg]
instantiation: FunctionType

class Config:
# Needed to avoid random '\n's in the pydantic description
Expand All @@ -288,19 +268,18 @@ class Config:
class CallIndirect(DataflowOp):
"""Call a function indirectly.
Like call, but the first input is a standard dataflow graph type."""
Like call, but the first input is a standard dataflow graph type.
"""

op: Literal["CallIndirect"] = "CallIndirect"
signature: FunctionType = Field(default_factory=FunctionType.empty)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
fun_ty = in_types[0]
assert isinstance(fun_ty, PolyFuncType)
poly_func = cast(PolyFuncType, fun_ty)
assert len(poly_func.params) == 0
assert len(poly_func.body.input) == len(in_types) - 1
assert len(poly_func.body.output) == len(out_types)
self.signature = poly_func.body
fun_ty = in_types[0].root
assert isinstance(fun_ty, FunctionType)
assert len(fun_ty.input) == len(in_types) - 1
assert len(fun_ty.output) == len(out_types)
self.signature = fun_ty


class LoadConstant(DataflowOp):
Expand Down Expand Up @@ -343,12 +322,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# First port is a predicate, i.e. a sum of tuple types. We need to unpack
# those into a list of type rows
pred = in_types[0]
if isinstance(pred, tys.UnitSum):
self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)]
assert isinstance(pred.root, tys.SumType)
sum = pred.root.root
if isinstance(sum, tys.UnitSum):
self.sum_rows = [[] for _ in range(sum.size)]
else:
assert isinstance(pred, tys.GeneralSum)
assert isinstance(sum, tys.GeneralSum)
self.sum_rows = []
for ty in pred.rows:
for ty in sum.rows:
self.sum_rows.append(ty)
self.other_inputs = list(in_types[1:])
self.outputs = list(out_types)
Expand Down
14 changes: 5 additions & 9 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import sys
from enum import Enum
from typing import Annotated, Any, Literal, Optional, Union
from typing import Annotated, Any, Literal, Union

from pydantic import (
BaseModel,
Expand All @@ -27,6 +27,7 @@ def _json_custom_error_validator(
Used to define named recursive alias types.
"""
return handler(value)
try:
return handler(value)
except ValidationError as err:
Expand All @@ -37,12 +38,7 @@ def _json_custom_error_validator(


ExtensionId = str


class ExtensionSet(RootModel):
"""A set of extensions ids."""

root: Optional[list[ExtensionId]] = Field(default=None)
ExtensionSet = list[ExtensionId]


# --------------------------------------------
Expand Down Expand Up @@ -204,11 +200,11 @@ class FunctionType(BaseModel):
input: "TypeRow" # Value inputs of the function.
output: "TypeRow" # Value outputs of the function.
# The extension requirements which are added by the operation
extension_reqs: "ExtensionSet" = Field(default_factory=list)
extension_reqs: ExtensionSet = Field(default_factory=list)

@classmethod
def empty(cls) -> "FunctionType":
return FunctionType(input=[], output=[], extension_reqs=ExtensionSet([]))
return FunctionType(input=[], output=[], extension_reqs=[])

class Config:
# Needed to avoid random '\n's in the pydantic description
Expand Down
Loading

0 comments on commit c92d312

Please sign in to comment.