From c92d312b941637d3c04fda2096d8dddd5edfda0f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 22 Apr 2024 10:11:11 +0100 Subject: [PATCH] fix: Schema issues --- hugr-py/src/hugr/serialization/ops.py | 69 ++-- hugr-py/src/hugr/serialization/tys.py | 14 +- specification/schema/hugr_schema_v1.json | 385 +++++++++++++++--- .../schema/testing_hugr_schema_v1.json | 28 +- 4 files changed, 378 insertions(+), 118 deletions(-) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 299013baa..e6da54949 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,7 +1,7 @@ import inspect import sys from abc import ABC -from typing import Any, Literal, cast +from typing import Any, Literal from pydantic import BaseModel, Field, RootModel @@ -25,7 +25,7 @@ class BaseOp(ABC, BaseModel): # Parent node index of node the op belongs to, used only at serialization time parent: NodeID - input_extensions: ExtensionSet = Field(default_factory=ExtensionSet) + input_extensions: ExtensionSet | None = Field(default=None) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: """Hook to insert type information from the input and output ports into the @@ -56,14 +56,7 @@ class FuncDefn(BaseOp): op: Literal["FuncDefn"] = "FuncDefn" name: str - signature: PolyFuncType = Field(default_factory=PolyFuncType.empty) - - def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: - assert len(in_types) == 0 - assert len(out_types) == 1 - out = out_types[0] - assert isinstance(out, PolyFuncType) - self.signature = out # TODO: Extensions + signature: PolyFuncType class FuncDecl(BaseOp): @@ -71,14 +64,7 @@ class FuncDecl(BaseOp): op: Literal["FuncDecl"] = "FuncDecl" name: str - signature: PolyFuncType = Field(default_factory=PolyFuncType.empty) - - def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: - assert len(in_types) == 0 - assert len(out_types) == 1 - out = out_types[0] - assert isinstance(out, PolyFuncType) - self.signature = out + signature: PolyFuncType CustomConst = Any # TODO @@ -183,13 +169,13 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: self.inputs = inputs - pred = outputs[0] - assert isinstance(pred, tys.UnitSum | tys.GeneralSum) - if isinstance(pred, tys.UnitSum): - self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)] + pred = outputs[0].root + assert isinstance(pred, tys.SumType) + if isinstance(pred.root, tys.UnitSum): + self.sum_rows = [[] for _ in range(pred.root.size)] else: self.sum_rows = [] - for variant in pred.rows: + for variant in pred.root.rows: self.sum_rows.append(variant) self.other_outputs = outputs[1:] @@ -263,15 +249,9 @@ class Call(DataflowOp): """ op: Literal["Call"] = "Call" - signature: FunctionType = Field(default_factory=FunctionType.empty) - - def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: - # The constE edge comes after the value inputs - fun_ty = in_types[-1] - assert isinstance(fun_ty, PolyFuncType) - poly_func = cast(PolyFuncType, fun_ty) - assert len(poly_func.params) == 0 - self.signature = poly_func.body + func_sig: PolyFuncType + type_args: list[tys.TypeArg] + instantiation: FunctionType class Config: # Needed to avoid random '\n's in the pydantic description @@ -288,19 +268,18 @@ class Config: class CallIndirect(DataflowOp): """Call a function indirectly. - Like call, but the first input is a standard dataflow graph type.""" + Like call, but the first input is a standard dataflow graph type. + """ op: Literal["CallIndirect"] = "CallIndirect" signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: - fun_ty = in_types[0] - assert isinstance(fun_ty, PolyFuncType) - poly_func = cast(PolyFuncType, fun_ty) - assert len(poly_func.params) == 0 - assert len(poly_func.body.input) == len(in_types) - 1 - assert len(poly_func.body.output) == len(out_types) - self.signature = poly_func.body + fun_ty = in_types[0].root + assert isinstance(fun_ty, FunctionType) + assert len(fun_ty.input) == len(in_types) - 1 + assert len(fun_ty.output) == len(out_types) + self.signature = fun_ty class LoadConstant(DataflowOp): @@ -343,12 +322,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # First port is a predicate, i.e. a sum of tuple types. We need to unpack # those into a list of type rows pred = in_types[0] - if isinstance(pred, tys.UnitSum): - self.sum_rows = [[] for _ in range(cast(tys.UnitSum, pred).size)] + assert isinstance(pred.root, tys.SumType) + sum = pred.root.root + if isinstance(sum, tys.UnitSum): + self.sum_rows = [[] for _ in range(sum.size)] else: - assert isinstance(pred, tys.GeneralSum) + assert isinstance(sum, tys.GeneralSum) self.sum_rows = [] - for ty in pred.rows: + for ty in sum.rows: self.sum_rows.append(ty) self.other_inputs = list(in_types[1:]) self.outputs = list(out_types) diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index d591e3a5a..4058e04d5 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -1,7 +1,7 @@ import inspect import sys from enum import Enum -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Union from pydantic import ( BaseModel, @@ -27,6 +27,7 @@ def _json_custom_error_validator( Used to define named recursive alias types. """ + return handler(value) try: return handler(value) except ValidationError as err: @@ -37,12 +38,7 @@ def _json_custom_error_validator( ExtensionId = str - - -class ExtensionSet(RootModel): - """A set of extensions ids.""" - - root: Optional[list[ExtensionId]] = Field(default=None) +ExtensionSet = list[ExtensionId] # -------------------------------------------- @@ -204,11 +200,11 @@ class FunctionType(BaseModel): input: "TypeRow" # Value inputs of the function. output: "TypeRow" # Value outputs of the function. # The extension requirements which are added by the operation - extension_reqs: "ExtensionSet" = Field(default_factory=list) + extension_reqs: ExtensionSet = Field(default_factory=list) @classmethod def empty(cls) -> "FunctionType": - return FunctionType(input=[], output=[], extension_reqs=ExtensionSet([])) + return FunctionType(input=[], output=[], extension_reqs=[]) class Config: # Needed to avoid random '\n's in the pydantic description diff --git a/specification/schema/hugr_schema_v1.json b/specification/schema/hugr_schema_v1.json index 3fb8c8961..94ff8f9ef 100644 --- a/specification/schema/hugr_schema_v1.json +++ b/specification/schema/hugr_schema_v1.json @@ -34,7 +34,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "AliasDecl", @@ -147,7 +159,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "CFG", @@ -176,7 +200,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Call", @@ -187,12 +223,25 @@ "title": "Op", "type": "string" }, - "signature": { + "func_sig": { + "$ref": "#/$defs/PolyFuncType" + }, + "type_args": { + "items": { + "$ref": "#/$defs/TypeArg" + }, + "title": "Type Args", + "type": "array" + }, + "instantiation": { "$ref": "#/$defs/FunctionType" } }, "required": [ - "parent" + "parent", + "func_sig", + "type_args", + "instantiation" ], "title": "Call", "type": "object" @@ -205,7 +254,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "CallIndirect", @@ -234,7 +295,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Case", @@ -263,7 +336,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Conditional", @@ -300,7 +385,11 @@ "type": "array" }, "extension_delta": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Extension Delta", + "type": "array" } }, "required": [ @@ -317,7 +406,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Const", @@ -348,7 +449,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "CustomOp", @@ -417,7 +530,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "DFG", @@ -446,7 +571,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "DataflowBlock", @@ -482,7 +619,11 @@ "type": "array" }, "extension_delta": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Extension Delta", + "type": "array" } }, "required": [ @@ -504,7 +645,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "ExitBlock", @@ -530,22 +683,6 @@ "title": "ExitBlock", "type": "object" }, - "ExtensionSet": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "default": null, - "description": "A set of extensions ids.", - "title": "ExtensionSet" - }, "ExtensionValue": { "description": "An extension constant value, that can check it is of a given [CustomType].", "properties": { @@ -581,7 +718,11 @@ "type": "string" }, "es": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Es", + "type": "array" } }, "required": [ @@ -613,7 +754,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "FuncDecl", @@ -634,7 +787,8 @@ }, "required": [ "parent", - "name" + "name", + "signature" ], "title": "FuncDecl", "type": "object" @@ -647,7 +801,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "FuncDefn", @@ -668,7 +834,8 @@ }, "required": [ "parent", - "name" + "name", + "signature" ], "title": "FuncDefn", "type": "object" @@ -700,7 +867,11 @@ "type": "array" }, "extension_reqs": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Extension Reqs", + "type": "array" } }, "required": [ @@ -770,7 +941,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Input", @@ -803,7 +986,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Lift", @@ -863,7 +1058,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "LoadConstant", @@ -893,7 +1100,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "MakeTuple", @@ -926,7 +1145,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Module", @@ -952,7 +1183,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Noop", @@ -1167,7 +1410,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Output", @@ -1316,7 +1571,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "Tag", @@ -1379,7 +1646,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "TailLoop", @@ -1677,7 +1956,19 @@ "type": "integer" }, "input_extensions": { - "$ref": "#/$defs/ExtensionSet" + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Input Extensions" }, "op": { "const": "UnpackTuple", diff --git a/specification/schema/testing_hugr_schema_v1.json b/specification/schema/testing_hugr_schema_v1.json index 49240bcfd..d9f70a3d7 100644 --- a/specification/schema/testing_hugr_schema_v1.json +++ b/specification/schema/testing_hugr_schema_v1.json @@ -123,22 +123,6 @@ "title": "CustomTypeArg", "type": "object" }, - "ExtensionSet": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "default": null, - "description": "A set of extensions ids.", - "title": "ExtensionSet" - }, "ExtensionValue": { "description": "An extension constant value, that can check it is of a given [CustomType].", "properties": { @@ -174,7 +158,11 @@ "type": "string" }, "es": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Es", + "type": "array" } }, "required": [ @@ -225,7 +213,11 @@ "type": "array" }, "extension_reqs": { - "$ref": "#/$defs/ExtensionSet" + "items": { + "type": "string" + }, + "title": "Extension Reqs", + "type": "array" } }, "required": [