diff --git a/hugr-py/src/hugr/__init__.py b/hugr-py/src/hugr/__init__.py index 7b63396f5..ffd86ff5c 100644 --- a/hugr-py/src/hugr/__init__.py +++ b/hugr-py/src/hugr/__init__.py @@ -5,8 +5,3 @@ # This is updated by our release-please workflow, triggered by this # annotation: x-release-please-version __version__ = "0.2.1" - - -def it_works() -> str: - """Return a string to confirm that the package is installed and working.""" - return "It works!" diff --git a/hugr-py/src/hugr/_cfg.py b/hugr-py/src/hugr/cfg.py similarity index 93% rename from hugr-py/src/hugr/_cfg.py rename to hugr-py/src/hugr/cfg.py index bde255b32..0cec794b1 100644 --- a/hugr-py/src/hugr/_cfg.py +++ b/hugr-py/src/hugr/cfg.py @@ -2,14 +2,14 @@ from dataclasses import dataclass -import hugr._ops as ops - -from ._dfg import _DfBase -from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit -from ._hugr import Hugr, ParentBuilder -from ._node_port import Node, Wire, ToNode -from ._tys import TypeRow, Type -import hugr._val as val +import hugr.ops as ops + +from .dfg import _DfBase +from .exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit +from .hugr import Hugr, ParentBuilder +from .node_port import Node, Wire, ToNode +from .tys import TypeRow, Type +import hugr.val as val class Block(_DfBase[ops.DataflowBlock]): diff --git a/hugr-py/src/hugr/_cond_loop.py b/hugr-py/src/hugr/cond_loop.py similarity index 95% rename from hugr-py/src/hugr/_cond_loop.py rename to hugr-py/src/hugr/cond_loop.py index 6433a0f21..a1ac42830 100644 --- a/hugr-py/src/hugr/_cond_loop.py +++ b/hugr-py/src/hugr/cond_loop.py @@ -2,13 +2,13 @@ from dataclasses import dataclass -import hugr._ops as ops +import hugr.ops as ops -from ._dfg import _DfBase -from ._hugr import Hugr, ParentBuilder -from ._node_port import Node, Wire, ToNode +from .dfg import _DfBase +from .hugr import Hugr, ParentBuilder +from .node_port import Node, Wire, ToNode -from ._tys import Sum, TypeRow +from .tys import Sum, TypeRow class Case(_DfBase[ops.Case]): diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/dfg.py similarity index 94% rename from hugr-py/src/hugr/_dfg.py rename to hugr-py/src/hugr/dfg.py index 8fd42f846..99091b6c6 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/dfg.py @@ -10,9 +10,9 @@ from typing_extensions import Self -import hugr._ops as ops -import hugr._val as val -from hugr._tys import ( +import hugr.ops as ops +import hugr.val as val +from hugr.tys import ( Type, TypeRow, get_first_sum, @@ -23,13 +23,13 @@ ExtensionSet, ) -from ._exceptions import NoSiblingAncestor -from ._hugr import Hugr, ParentBuilder -from ._node_port import Node, OutPort, Wire, ToNode +from .exceptions import NoSiblingAncestor +from .hugr import Hugr, ParentBuilder +from .node_port import Node, OutPort, Wire, ToNode if TYPE_CHECKING: - from ._cfg import Cfg - from ._cond_loop import Conditional, If, TailLoop + from .cfg import Cfg + from .cond_loop import Conditional, If, TailLoop DP = TypeVar("DP", bound=ops.DfParentOp) @@ -96,7 +96,7 @@ def add_nested( self, *args: Wire, ) -> Dfg: - from ._dfg import Dfg + from .dfg import Dfg parent_op = ops.DFG(self._wire_types(args)) dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node) @@ -110,7 +110,7 @@ def add_cfg( self, *args: Wire, ) -> Cfg: - from ._cfg import Cfg + from .cfg import Cfg cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node) self._wire_up(cfg.parent_node, args) @@ -120,7 +120,7 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node: return self._insert_nested_impl(cfg, *args) def add_conditional(self, cond: Wire, *args: Wire) -> Conditional: - from ._cond_loop import Conditional + from .cond_loop import Conditional args = (cond, *args) (sum_, other_inputs) = get_first_sum(self._wire_types(args)) @@ -132,7 +132,7 @@ def insert_conditional(self, cond: Conditional, *args: Wire) -> Node: return self._insert_nested_impl(cond, *args) def add_if(self, cond: Wire, *args: Wire) -> If: - from ._cond_loop import If + from .cond_loop import If conditional = self.add_conditional(cond, *args) return If(conditional.add_case(1)) @@ -140,7 +140,7 @@ def add_if(self, cond: Wire, *args: Wire) -> If: def add_tail_loop( self, just_inputs: Sequence[Wire], rest: Sequence[Wire] ) -> TailLoop: - from ._cond_loop import TailLoop + from .cond_loop import TailLoop just_input_types = self._wire_types(just_inputs) rest_types = self._wire_types(rest) diff --git a/hugr-py/src/hugr/_exceptions.py b/hugr-py/src/hugr/exceptions.py similarity index 100% rename from hugr-py/src/hugr/_exceptions.py rename to hugr-py/src/hugr/exceptions.py diff --git a/hugr-py/src/hugr/_function.py b/hugr-py/src/hugr/function.py similarity index 87% rename from hugr-py/src/hugr/_function.py rename to hugr-py/src/hugr/function.py index c1c54984c..2e698c5d1 100644 --- a/hugr-py/src/hugr/_function.py +++ b/hugr-py/src/hugr/function.py @@ -2,13 +2,13 @@ from dataclasses import dataclass -import hugr._ops as ops -import hugr._val as val +import hugr.ops as ops +import hugr.val as val -from ._dfg import _DfBase -from hugr._node_port import Node -from ._hugr import Hugr -from ._tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound +from .dfg import _DfBase +from hugr.node_port import Node +from .hugr import Hugr +from .tys import TypeRow, TypeParam, PolyFuncType, Type, TypeBound @dataclass diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/hugr.py similarity index 97% rename from hugr-py/src/hugr/_hugr.py rename to hugr-py/src/hugr/hugr.py index 7ee6831fa..7fb9bc70c 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -13,15 +13,15 @@ ) -from hugr._ops import Op, DataflowOp, Const, Call -from hugr._tys import Type, Kind, ValueKind -from hugr._val import Value -from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort +from hugr.ops import Op, DataflowOp, Const, Call +from hugr.tys import Type, Kind, ValueKind +from hugr.val import Value +from hugr.node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.serial_hugr import SerialHugr from hugr.utils import BiMap -from ._exceptions import ParentBeforeChild +from .exceptions import ParentBeforeChild @dataclass() diff --git a/hugr-py/src/hugr/_node_port.py b/hugr-py/src/hugr/node_port.py similarity index 100% rename from hugr-py/src/hugr/_node_port.py rename to hugr-py/src/hugr/node_port.py diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/ops.py similarity index 99% rename from hugr-py/src/hugr/_ops.py rename to hugr-py/src/hugr/ops.py index a2b68153c..07ab362dc 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/ops.py @@ -5,13 +5,13 @@ from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops from hugr.utils import ser_it -import hugr._tys as tys -from hugr._node_port import Node, InPort, OutPort, Wire -import hugr._val as val -from ._exceptions import IncompleteOp +import hugr.tys as tys +from hugr.node_port import Node, InPort, OutPort, Wire +import hugr.val as val +from .exceptions import IncompleteOp if TYPE_CHECKING: - from hugr._hugr import Hugr + from hugr.hugr import Hugr @dataclass @@ -35,7 +35,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind: ... def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type: - from hugr._hugr import Direction + from hugr.node_port import Direction if port.direction == Direction.INCOMING: return sig.input[port.offset] diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index b0bcfb1ab..65706b047 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -6,7 +6,7 @@ from pydantic import Field, RootModel, ConfigDict -from . import tys +from . import tys as stys from .tys import ( ExtensionId, ExtensionSet, @@ -44,7 +44,7 @@ def display_name(self) -> str: return self.__class__.__name__ @abstractmethod - def deserialize(self) -> _ops.Op: + def deserialize(self) -> ops.Op: """Deserializes the model into the corresponding Op.""" @@ -58,8 +58,8 @@ class Module(BaseOp): op: Literal["Module"] = "Module" - def deserialize(self) -> _ops.Module: - return _ops.Module() + def deserialize(self) -> ops.Module: + return ops.Module() class FuncDefn(BaseOp): @@ -70,9 +70,9 @@ class FuncDefn(BaseOp): name: str signature: PolyFuncType - def deserialize(self) -> _ops.FuncDefn: + def deserialize(self) -> ops.FuncDefn: poly_func = self.signature.deserialize() - return _ops.FuncDefn( + return ops.FuncDefn( self.name, inputs=poly_func.body.input, _outputs=poly_func.body.output ) @@ -84,8 +84,8 @@ class FuncDecl(BaseOp): name: str signature: PolyFuncType - def deserialize(self) -> _ops.FuncDecl: - return _ops.FuncDecl(self.name, self.signature.deserialize()) + def deserialize(self) -> ops.FuncDecl: + return ops.FuncDecl(self.name, self.signature.deserialize()) class CustomConst(ConfiguredBaseModel): @@ -95,7 +95,7 @@ class CustomConst(ConfiguredBaseModel): class BaseValue(ABC, ConfiguredBaseModel): @abstractmethod - def deserialize(self) -> _val.Value: ... + def deserialize(self) -> val.Value: ... class ExtensionValue(BaseValue): @@ -106,8 +106,8 @@ class ExtensionValue(BaseValue): typ: Type value: CustomConst - def deserialize(self) -> _val.Value: - return _val.Extension(self.value.c, self.typ.deserialize(), self.value.v) + def deserialize(self) -> val.Value: + return val.Extension(self.value.c, self.typ.deserialize(), self.value.v) class FunctionValue(BaseValue): @@ -116,12 +116,12 @@ class FunctionValue(BaseValue): v: Literal["Function"] = Field(default="Function", title="ValueTag") hugr: Any - def deserialize(self) -> _val.Value: - from hugr._hugr import Hugr + def deserialize(self) -> val.Value: + from hugr.hugr import Hugr from hugr.serialization.serial_hugr import SerialHugr # pydantic stores the serialized dictionary because of the "Any" annotation - return _val.Function(Hugr.from_serial(SerialHugr(**self.hugr))) + return val.Function(Hugr.from_serial(SerialHugr(**self.hugr))) class TupleValue(BaseValue): @@ -130,8 +130,8 @@ class TupleValue(BaseValue): v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag") vs: list["Value"] - def deserialize(self) -> _val.Value: - return _val.Tuple(deser_it((v.root for v in self.vs))) + def deserialize(self) -> val.Value: + return val.Tuple(deser_it((v.root for v in self.vs))) class SumValue(BaseValue): @@ -153,8 +153,8 @@ class SumValue(BaseValue): } ) - def deserialize(self) -> _val.Value: - return _val.Sum( + def deserialize(self) -> val.Value: + return val.Sum( self.tag, self.typ.deserialize(), deser_it((v.root for v in self.vs)) ) @@ -168,7 +168,7 @@ class Value(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["v"]}) - def deserialize(self) -> _val.Value: + def deserialize(self) -> val.Value: return self.root.deserialize() @@ -178,8 +178,8 @@ class Const(BaseOp): op: Literal["Const"] = "Const" v: Value = Field() - def deserialize(self) -> _ops.Const: - return _ops.Const(self.v.deserialize()) + def deserialize(self) -> ops.Const: + return ops.Const(self.v.deserialize()) # ----------------------------------------------- @@ -204,8 +204,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: self.inputs = inputs pred = outputs[0].root - assert isinstance(pred, tys.SumType) - if isinstance(pred.root, tys.UnitSum): + assert isinstance(pred, stys.SumType) + if isinstance(pred.root, stys.UnitSum): self.sum_rows = [[] for _ in range(pred.root.size)] else: self.sum_rows = [] @@ -215,10 +215,10 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: # Needed to avoid random '\n's in the pydantic description - def deserialize(self) -> _ops.DataflowBlock: - return _ops.DataflowBlock( + def deserialize(self) -> ops.DataflowBlock: + return ops.DataflowBlock( inputs=deser_it(self.inputs), - _sum=_tys.Sum([deser_it(r) for r in self.sum_rows]), + _sum=tys.Sum([deser_it(r) for r in self.sum_rows]), _other_outputs=deser_it(self.other_outputs), ) @@ -243,8 +243,8 @@ class ExitBlock(BaseOp): } ) - def deserialize(self) -> _ops.ExitBlock: - return _ops.ExitBlock(deser_it(self.cfg_outputs)) + def deserialize(self) -> ops.ExitBlock: + return ops.ExitBlock(deser_it(self.cfg_outputs)) # --------------------------------------------- @@ -266,8 +266,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(in_types) == 0 self.types = list(out_types) - def deserialize(self) -> _ops.Input: - return _ops.Input(types=[t.deserialize() for t in self.types]) + def deserialize(self) -> ops.Input: + return ops.Input(types=[t.deserialize() for t in self.types]) class Output(DataflowOp): @@ -280,8 +280,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(out_types) == 0 self.types = list(in_types) - def deserialize(self) -> _ops.Output: - return _ops.Output(deser_it(self.types)) + def deserialize(self) -> ops.Output: + return ops.Output(deser_it(self.types)) class Call(DataflowOp): @@ -295,7 +295,7 @@ class Call(DataflowOp): op: Literal["Call"] = "Call" func_sig: PolyFuncType - type_args: list[tys.TypeArg] + type_args: list[stys.TypeArg] instantiation: FunctionType model_config = ConfigDict( @@ -310,8 +310,8 @@ class Call(DataflowOp): } ) - def deserialize(self) -> _ops.Call: - return _ops.Call( + def deserialize(self) -> ops.Call: + return ops.Call( self.func_sig.deserialize(), self.instantiation.deserialize(), deser_it(self.type_args), @@ -334,8 +334,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(fun_ty.output) == len(out_types) self.signature = fun_ty - def deserialize(self) -> _ops.CallIndirectDef: - return _ops.CallIndirectDef(self.signature.deserialize()) + def deserialize(self) -> ops.CallIndirectDef: + return ops.CallIndirectDef(self.signature.deserialize()) class LoadConstant(DataflowOp): @@ -344,8 +344,8 @@ class LoadConstant(DataflowOp): op: Literal["LoadConstant"] = "LoadConstant" datatype: Type - def deserialize(self) -> _ops.LoadConst: - return _ops.LoadConst(self.datatype.deserialize()) + def deserialize(self) -> ops.LoadConst: + return ops.LoadConst(self.datatype.deserialize()) class LoadFunction(DataflowOp): @@ -353,17 +353,17 @@ class LoadFunction(DataflowOp): op: Literal["LoadFunction"] = "LoadFunction" func_sig: PolyFuncType - type_args: list[tys.TypeArg] + type_args: list[stys.TypeArg] signature: FunctionType - def deserialize(self) -> _ops.LoadFunc: + def deserialize(self) -> ops.LoadFunc: signature = self.signature.deserialize() assert len(signature.input) == 0 (f_ty,) = signature.output assert isinstance( - f_ty, _tys.FunctionType + f_ty, tys.FunctionType ), "Expected single funciton type output" - return _ops.LoadFunc( + return ops.LoadFunc( self.func_sig.deserialize(), f_ty, deser_it(self.type_args), @@ -381,9 +381,9 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) - def deserialize(self) -> _ops.DFG: + def deserialize(self) -> ops.DFG: sig = self.signature.deserialize() - return _ops.DFG(sig.input, sig.output, sig.extension_reqs) + return ops.DFG(sig.input, sig.output, sig.extension_reqs) # ------------------------------------------------ @@ -407,21 +407,21 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # First port is a predicate, i.e. a sum of tuple types. We need to unpack # those into a list of type rows pred = in_types[0] - assert isinstance(pred.root, tys.SumType) + assert isinstance(pred.root, stys.SumType) sum = pred.root.root - if isinstance(sum, tys.UnitSum): + if isinstance(sum, stys.UnitSum): self.sum_rows = [[] for _ in range(sum.size)] else: - assert isinstance(sum, tys.GeneralSum) + assert isinstance(sum, stys.GeneralSum) self.sum_rows = [] for ty in sum.rows: self.sum_rows.append(ty) self.other_inputs = list(in_types[1:]) self.outputs = list(out_types) - def deserialize(self) -> _ops.Conditional: - return _ops.Conditional( - _tys.Sum([deser_it(r) for r in self.sum_rows]), + def deserialize(self) -> ops.Conditional: + return ops.Conditional( + tys.Sum([deser_it(r) for r in self.sum_rows]), deser_it(self.other_inputs), deser_it(self.outputs), ) @@ -435,13 +435,13 @@ class Case(BaseOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = tys.FunctionType( + self.signature = stys.FunctionType( input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) - def deserialize(self) -> _ops.Case: + def deserialize(self) -> ops.Case: sig = self.signature.deserialize() - return _ops.Case(inputs=sig.input, _outputs=sig.output) + return ops.Case(inputs=sig.input, _outputs=sig.output) class TailLoop(DataflowOp): @@ -460,8 +460,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # self.just_outputs = list(out_types) self.rest = list(in_types) - def deserialize(self) -> _ops.TailLoop: - return _ops.TailLoop( + def deserialize(self) -> ops.TailLoop: + return ops.TailLoop( just_inputs=deser_it(self.just_inputs), _just_outputs=deser_it(self.just_outputs), rest=deser_it(self.rest), @@ -480,9 +480,9 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) - def deserialize(self) -> _ops.CFG: + def deserialize(self) -> ops.CFG: sig = self.signature.deserialize() - return _ops.CFG(inputs=sig.input, _outputs=sig.output) + return ops.CFG(inputs=sig.input, _outputs=sig.output) ControlFlowOp = Conditional | TailLoop | CFG @@ -495,18 +495,18 @@ class CustomOp(DataflowOp): op: Literal["CustomOp"] = "CustomOp" extension: ExtensionId op_name: str - signature: tys.FunctionType = Field(default_factory=tys.FunctionType.empty) + signature: stys.FunctionType = Field(default_factory=stys.FunctionType.empty) description: str = "" - args: list[tys.TypeArg] = Field(default_factory=list) + args: list[stys.TypeArg] = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: - self.signature = tys.FunctionType(input=list(in_types), output=list(out_types)) + self.signature = stys.FunctionType(input=list(in_types), output=list(out_types)) def display_name(self) -> str: return self.op_name - def deserialize(self) -> _ops.Custom: - return _ops.Custom( + def deserialize(self) -> ops.Custom: + return ops.Custom( extension=self.extension, op_name=self.op_name, signature=self.signature.deserialize(), @@ -536,8 +536,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert in_types[0] == out_types[0] self.ty = in_types[0] - def deserialize(self) -> _ops.NoopDef: - return _ops.NoopDef(self.ty.deserialize()) + def deserialize(self) -> ops.NoopDef: + return ops.NoopDef(self.ty.deserialize()) class MakeTuple(DataflowOp): @@ -552,8 +552,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) - def deserialize(self) -> _ops.MakeTupleDef: - return _ops.MakeTupleDef(deser_it(self.tys)) + def deserialize(self) -> ops.MakeTupleDef: + return ops.MakeTupleDef(deser_it(self.tys)) class UnpackTuple(DataflowOp): @@ -565,8 +565,8 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) - def deserialize(self) -> _ops.UnpackTupleDef: - return _ops.UnpackTupleDef(deser_it(self.tys)) + def deserialize(self) -> ops.UnpackTupleDef: + return ops.UnpackTupleDef(deser_it(self.tys)) class Tag(DataflowOp): @@ -576,10 +576,10 @@ class Tag(DataflowOp): tag: int # The variant to create. variants: list[TypeRow] # The variants of the sum type. - def deserialize(self) -> _ops.Tag: - return _ops.Tag( + def deserialize(self) -> ops.Tag: + return ops.Tag( tag=self.tag, - sum_ty=_tys.Sum([deser_it(v) for v in self.variants]), + sum_ty=tys.Sum([deser_it(v) for v in self.variants]), ) @@ -590,8 +590,8 @@ class Lift(DataflowOp): type_row: TypeRow new_extension: ExtensionId - def deserialize(self) -> _ops.Lift: - return _ops.Lift( + def deserialize(self) -> ops.Lift: + return ops.Lift( _type_row=deser_it(self.type_row), new_extension=self.new_extension, ) @@ -602,8 +602,8 @@ class AliasDecl(BaseOp): name: str bound: TypeBound - def deserialize(self) -> _ops.AliasDecl: - return _ops.AliasDecl(self.name, self.bound) + def deserialize(self) -> ops.AliasDecl: + return ops.AliasDecl(self.name, self.bound) class AliasDefn(BaseOp): @@ -611,8 +611,8 @@ class AliasDefn(BaseOp): name: str definition: Type - def deserialize(self) -> _ops.AliasDefn: - return _ops.AliasDefn(self.name, self.definition.deserialize()) + def deserialize(self) -> ops.AliasDefn: + return ops.AliasDefn(self.name, self.definition.deserialize()) class OpType(RootModel): @@ -683,6 +683,6 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): tys_model_rebuild(dict(classes)) # needed to avoid circular imports -from hugr import _ops # noqa: E402 -from hugr import _val # noqa: E402 -from hugr import _tys # noqa: E402 +from hugr import ops # noqa: E402 +from hugr import val # noqa: E402 +from hugr import tys # noqa: E402 diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index 7cd6a3b0d..053deefcd 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -63,54 +63,54 @@ def update_model_config(cls, config: ConfigDict): class BaseTypeParam(ABC, ConfiguredBaseModel): @abstractmethod - def deserialize(self) -> _tys.TypeParam: ... + def deserialize(self) -> tys.TypeParam: ... class TypeTypeParam(BaseTypeParam): tp: Literal["Type"] = "Type" b: "TypeBound" - def deserialize(self) -> _tys.TypeTypeParam: - return _tys.TypeTypeParam(bound=self.b) + def deserialize(self) -> tys.TypeTypeParam: + return tys.TypeTypeParam(bound=self.b) class BoundedNatParam(BaseTypeParam): tp: Literal["BoundedNat"] = "BoundedNat" bound: int | None - def deserialize(self) -> _tys.BoundedNatParam: - return _tys.BoundedNatParam(upper_bound=self.bound) + def deserialize(self) -> tys.BoundedNatParam: + return tys.BoundedNatParam(upper_bound=self.bound) class OpaqueParam(BaseTypeParam): tp: Literal["Opaque"] = "Opaque" ty: "Opaque" - def deserialize(self) -> _tys.OpaqueParam: - return _tys.OpaqueParam(ty=self.ty.deserialize()) + def deserialize(self) -> tys.OpaqueParam: + return tys.OpaqueParam(ty=self.ty.deserialize()) class ListParam(BaseTypeParam): tp: Literal["List"] = "List" param: "TypeParam" - def deserialize(self) -> _tys.ListParam: - return _tys.ListParam(param=self.param.deserialize()) + def deserialize(self) -> tys.ListParam: + return tys.ListParam(param=self.param.deserialize()) class TupleParam(BaseTypeParam): tp: Literal["Tuple"] = "Tuple" params: list["TypeParam"] - def deserialize(self) -> _tys.TupleParam: - return _tys.TupleParam(params=deser_it(self.params)) + def deserialize(self) -> tys.TupleParam: + return tys.TupleParam(params=deser_it(self.params)) class ExtensionsParam(BaseTypeParam): tp: Literal["Extensions"] = "Extensions" - def deserialize(self) -> _tys.ExtensionsParam: - return _tys.ExtensionsParam() + def deserialize(self) -> tys.ExtensionsParam: + return tys.ExtensionsParam() class TypeParam(RootModel): @@ -128,7 +128,7 @@ class TypeParam(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["tp"]}) - def deserialize(self) -> _tys.TypeParam: + def deserialize(self) -> tys.TypeParam: return self.root.deserialize() @@ -139,23 +139,23 @@ def deserialize(self) -> _tys.TypeParam: class BaseTypeArg(ABC, ConfiguredBaseModel): @abstractmethod - def deserialize(self) -> _tys.TypeArg: ... + def deserialize(self) -> tys.TypeArg: ... class TypeTypeArg(BaseTypeArg): tya: Literal["Type"] = "Type" ty: "Type" - def deserialize(self) -> _tys.TypeTypeArg: - return _tys.TypeTypeArg(ty=self.ty.deserialize()) + def deserialize(self) -> tys.TypeTypeArg: + return tys.TypeTypeArg(ty=self.ty.deserialize()) class BoundedNatArg(BaseTypeArg): tya: Literal["BoundedNat"] = "BoundedNat" n: int - def deserialize(self) -> _tys.BoundedNatArg: - return _tys.BoundedNatArg(n=self.n) + def deserialize(self) -> tys.BoundedNatArg: + return tys.BoundedNatArg(n=self.n) class OpaqueArg(BaseTypeArg): @@ -163,24 +163,24 @@ class OpaqueArg(BaseTypeArg): typ: "Opaque" value: Any - def deserialize(self) -> _tys.OpaqueArg: - return _tys.OpaqueArg(ty=self.typ.deserialize(), value=self.value) + def deserialize(self) -> tys.OpaqueArg: + return tys.OpaqueArg(ty=self.typ.deserialize(), value=self.value) class SequenceArg(BaseTypeArg): tya: Literal["Sequence"] = "Sequence" elems: list["TypeArg"] - def deserialize(self) -> _tys.SequenceArg: - return _tys.SequenceArg(elems=deser_it(self.elems)) + def deserialize(self) -> tys.SequenceArg: + return tys.SequenceArg(elems=deser_it(self.elems)) class ExtensionsArg(BaseTypeArg): tya: Literal["Extensions"] = "Extensions" es: ExtensionSet - def deserialize(self) -> _tys.ExtensionsArg: - return _tys.ExtensionsArg(extensions=self.es) + def deserialize(self) -> tys.ExtensionsArg: + return tys.ExtensionsArg(extensions=self.es) class VariableArg(BaseTypeArg): @@ -188,8 +188,8 @@ class VariableArg(BaseTypeArg): idx: int cached_decl: TypeParam - def deserialize(self) -> _tys.VariableArg: - return _tys.VariableArg(idx=self.idx, param=self.cached_decl.deserialize()) + def deserialize(self) -> tys.VariableArg: + return tys.VariableArg(idx=self.idx, param=self.cached_decl.deserialize()) class TypeArg(RootModel): @@ -207,7 +207,7 @@ class TypeArg(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["tya"]}) - def deserialize(self) -> _tys.TypeArg: + def deserialize(self) -> tys.TypeArg: return self.root.deserialize() @@ -218,7 +218,7 @@ def deserialize(self) -> _tys.TypeArg: class BaseType(ABC, ConfiguredBaseModel): @abstractmethod - def deserialize(self) -> _tys.Type: ... + def deserialize(self) -> tys.Type: ... class MultiContainer(BaseType): @@ -231,8 +231,8 @@ class Array(MultiContainer): t: Literal["Array"] = "Array" len: int - def deserialize(self) -> _tys.Array: - return _tys.Array(ty=self.ty.deserialize(), size=self.len) + def deserialize(self) -> tys.Array: + return tys.Array(ty=self.ty.deserialize(), size=self.len) class UnitSum(BaseType): @@ -242,8 +242,8 @@ class UnitSum(BaseType): s: Literal["Unit"] = "Unit" size: int - def deserialize(self) -> _tys.UnitSum: - return _tys.UnitSum(size=self.size) + def deserialize(self) -> tys.UnitSum: + return tys.UnitSum(size=self.size) class GeneralSum(BaseType): @@ -253,8 +253,8 @@ class GeneralSum(BaseType): s: Literal["General"] = "General" rows: list["TypeRow"] - def deserialize(self) -> _tys.Sum: - return _tys.Sum(variant_rows=[[t.deserialize() for t in r] for r in self.rows]) + def deserialize(self) -> tys.Sum: + return tys.Sum(variant_rows=[[t.deserialize() for t in r] for r in self.rows]) class SumType(RootModel): @@ -267,7 +267,7 @@ def t(self) -> str: model_config = ConfigDict(json_schema_extra={"required": ["s"]}) - def deserialize(self) -> _tys.Sum | _tys.UnitSum: + def deserialize(self) -> tys.Sum | tys.UnitSum: return self.root.deserialize() @@ -283,8 +283,8 @@ class Variable(BaseType): i: int b: "TypeBound" - def deserialize(self) -> _tys.Variable: - return _tys.Variable(idx=self.i, bound=self.b) + def deserialize(self) -> tys.Variable: + return tys.Variable(idx=self.i, bound=self.b) class RowVar(BaseType): @@ -295,8 +295,8 @@ class RowVar(BaseType): i: int b: "TypeBound" - def deserialize(self) -> _tys.RowVariable: - return _tys.RowVariable(idx=self.i, bound=self.b) + def deserialize(self) -> tys.RowVariable: + return tys.RowVariable(idx=self.i, bound=self.b) class USize(BaseType): @@ -304,8 +304,8 @@ class USize(BaseType): t: Literal["I"] = "I" - def deserialize(self) -> _tys.USize: - return _tys.USize() + def deserialize(self) -> tys.USize: + return tys.USize() class FunctionType(BaseType): @@ -323,8 +323,8 @@ class FunctionType(BaseType): def empty(cls) -> "FunctionType": return FunctionType(input=[], output=[], extension_reqs=[]) - def deserialize(self) -> _tys.FunctionType: - return _tys.FunctionType( + def deserialize(self) -> tys.FunctionType: + return tys.FunctionType( input=deser_it(self.input), output=deser_it(self.output), extension_reqs=self.extension_reqs, @@ -357,8 +357,8 @@ class PolyFuncType(BaseType): def empty(cls) -> "PolyFuncType": return PolyFuncType(params=[], body=FunctionType.empty()) - def deserialize(self) -> _tys.PolyFuncType: - return _tys.PolyFuncType( + def deserialize(self) -> tys.PolyFuncType: + return tys.PolyFuncType( params=deser_it(self.params), body=self.body.deserialize(), ) @@ -400,8 +400,8 @@ class Opaque(BaseType): args: list[TypeArg] bound: TypeBound - def deserialize(self) -> _tys.Opaque: - return _tys.Opaque( + def deserialize(self) -> tys.Opaque: + return tys.Opaque( extension=self.extension, id=self.id, args=deser_it(self.args), @@ -416,8 +416,8 @@ class Alias(BaseType): bound: TypeBound name: str - def deserialize(self) -> _tys.Alias: - return _tys.Alias(name=self.name, bound=self.bound) + def deserialize(self) -> tys.Alias: + return tys.Alias(name=self.name, bound=self.bound) # ---------------------------------------------- @@ -430,8 +430,8 @@ class Qubit(BaseType): t: Literal["Q"] = "Q" - def deserialize(self) -> _tys.QubitDef: - return _tys.Qubit + def deserialize(self) -> tys.QubitDef: + return tys.Qubit class Type(RootModel): @@ -453,7 +453,7 @@ class Type(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["t"]}) - def deserialize(self) -> _tys.Type: + def deserialize(self) -> tys.Type: return self.root.deserialize() @@ -487,4 +487,4 @@ def model_rebuild( model_rebuild(dict(classes)) -from hugr import _tys # noqa: E402 # needed to avoid circular imports +from hugr import tys # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/tys.py similarity index 100% rename from hugr-py/src/hugr/_tys.py rename to hugr-py/src/hugr/tys.py diff --git a/hugr-py/src/hugr/_val.py b/hugr-py/src/hugr/val.py similarity index 97% rename from hugr-py/src/hugr/_val.py rename to hugr-py/src/hugr/val.py index d2a6277eb..deb898c09 100644 --- a/hugr-py/src/hugr/_val.py +++ b/hugr-py/src/hugr/val.py @@ -3,11 +3,11 @@ from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING import hugr.serialization.ops as sops import hugr.serialization.tys as stys -import hugr._tys as tys +import hugr.tys as tys from hugr.utils import ser_it if TYPE_CHECKING: - from hugr._hugr import Hugr + from hugr.hugr import Hugr @runtime_checkable diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py new file mode 100644 index 000000000..33f792d0b --- /dev/null +++ b/hugr-py/tests/conftest.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import subprocess +import os +import pathlib +from hugr.node_port import Wire + +from hugr.hugr import Hugr +from hugr.ops import Custom, Command +from hugr.serialization import SerialHugr +import hugr.tys as tys +import hugr.val as val +import json + + +def int_t(width: int) -> tys.Opaque: + return tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=width)], + bound=tys.TypeBound.Eq, + ) + + +INT_T = int_t(5) + + +@dataclass +class IntVal(val.ExtensionValue): + v: int + + def to_value(self) -> val.Extension: + return val.Extension("int", INT_T, self.v) + + +@dataclass +class LogicOps(Custom): + extension: tys.ExtensionId = "logic" + + +# TODO get from YAML +@dataclass +class NotDef(LogicOps): + num_out: int | None = 1 + op_name: str = "Not" + signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) + + def __call__(self, a: Wire) -> Command: + return super().__call__(a) + + +Not = NotDef() + + +@dataclass +class QuantumOps(Custom): + extension: tys.ExtensionId = "tket2.quantum" + + +@dataclass +class OneQbGate(QuantumOps): + op_name: str + num_out: int | None = 1 + signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +H = OneQbGate("H") + + +@dataclass +class MeasureDef(QuantumOps): + op_name: str = "Measure" + num_out: int | None = 2 + signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) + + def __call__(self, q: Wire) -> Command: + return super().__call__(q) + + +Measure = MeasureDef() + + +@dataclass +class IntOps(Custom): + extension: tys.ExtensionId = "arithmetic.int" + + +ARG_5 = tys.BoundedNatArg(n=5) + + +@dataclass +class DivModDef(IntOps): + num_out: int | None = 2 + extension: tys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: tys.FunctionType = field( + default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) + + +DivMod = DivModDef() + + +def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): + workspace_dir = pathlib.Path(__file__).parent.parent.parent + # use the HUGR_BIN environment variable if set, otherwise use the debug build + bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) + cmd = [bin_loc, "-"] + + if mermaid: + cmd.append("--mermaid") + serial = h.to_serial().to_json() + subprocess.run(cmd, check=True, input=serial.encode()) + + if roundtrip: + h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) + assert serial == h2.to_serial().to_json() diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index 269c48285..ddd3b9c32 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,9 +1,9 @@ -from hugr._cfg import Cfg -import hugr._tys as tys -import hugr._val as val -from hugr._dfg import Dfg -import hugr._ops as ops -from .test_hugr_build import _validate, INT_T, DivMod, IntVal +from hugr.cfg import Cfg +import hugr.tys as tys +import hugr.val as val +from hugr.dfg import Dfg +import hugr.ops as ops +from .conftest import validate, INT_T, DivMod, IntVal def build_basic_cfg(cfg: Cfg) -> None: @@ -16,7 +16,7 @@ def build_basic_cfg(cfg: Cfg) -> None: def test_basic_cfg() -> None: cfg = Cfg([tys.Bool]) build_basic_cfg(cfg) - _validate(cfg.hugr) + validate(cfg.hugr) def test_branch() -> None: @@ -34,7 +34,7 @@ def test_branch() -> None: cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) - _validate(cfg.hugr) + validate(cfg.hugr) def test_nested_cfg() -> None: @@ -45,7 +45,7 @@ def test_nested_cfg() -> None: build_basic_cfg(cfg) dfg.set_outputs(cfg) - _validate(dfg.hugr) + validate(dfg.hugr) def test_dom_edge() -> None: @@ -64,7 +64,7 @@ def test_dom_edge() -> None: cfg.branch_exit(middle_1[0]) cfg.branch_exit(middle_2[0]) - _validate(cfg.hugr) + validate(cfg.hugr) def test_asymm_types() -> None: @@ -86,7 +86,7 @@ def test_asymm_types() -> None: cfg.branch_exit(entry[1]) cfg.branch_exit(middle[0]) - _validate(cfg.hugr) + validate(cfg.hugr) # TODO loop diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 24381e06c..75791317b 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -1,10 +1,10 @@ -from hugr._cond_loop import Conditional, ConditionalError, TailLoop -from hugr._dfg import Dfg -import hugr._tys as tys -import hugr._ops as ops -import hugr._val as val +from hugr.cond_loop import Conditional, ConditionalError, TailLoop +from hugr.dfg import Dfg +import hugr.tys as tys +import hugr.ops as ops +import hugr.val as val import pytest -from .test_hugr_build import INT_T, _validate, IntVal, H, Measure +from .conftest import INT_T, validate, IntVal, H, Measure SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) @@ -25,7 +25,7 @@ def build_cond(h: Conditional) -> None: def test_cond() -> None: h = Conditional(SUM_T, [tys.Bool]) build_cond(h) - _validate(h.hugr) + validate(h.hugr) def test_nested_cond() -> None: @@ -35,7 +35,7 @@ def test_nested_cond() -> None: cond = h.add_conditional(tagged_q, h.load(val.TRUE)) build_cond(cond) h.set_outputs(*cond[:2]) - _validate(h.hugr) + validate(h.hugr) # build then insert con = Conditional(SUM_T, [tys.Bool]) @@ -46,7 +46,7 @@ def test_nested_cond() -> None: tagged_q = h.add(ops.Tag(0, SUM_T)(q)) cond_n = h.insert_conditional(con, tagged_q, h.load(val.TRUE)) h.set_outputs(*cond_n[:2]) - _validate(h.hugr) + validate(h.hugr) def test_if_else() -> None: @@ -63,7 +63,7 @@ def test_if_else() -> None: cond = else_.finish() h.set_outputs(cond) - _validate(h.hugr) + validate(h.hugr) def test_tail_loop() -> None: @@ -80,7 +80,7 @@ def build_tl(tl: TailLoop) -> None: build_tl(tl) h.set_outputs(tl) - _validate(h.hugr) + validate(h.hugr) # build then insert tl = TailLoop([], [tys.Qubit]) @@ -90,7 +90,7 @@ def build_tl(tl: TailLoop) -> None: (q,) = h.inputs() tl_n = h.insert_tail_loop(tl, q) h.set_outputs(tl_n) - _validate(h.hugr) + validate(h.hugr) def test_complex_tail_loop() -> None: @@ -119,6 +119,6 @@ def test_complex_tail_loop() -> None: # loop returns [qubit, int, bool] h.set_outputs(*tl[:3]) - _validate(h.hugr, True) + validate(h.hugr, True) # TODO rewrite with context managers diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 89e549108..a01b89d97 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,128 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass, field -import subprocess -import os -import pathlib -from hugr._node_port import Node, Wire, _SubPort - -from hugr._hugr import Hugr -from hugr._dfg import Dfg, _ancestral_sibling -from hugr._ops import Custom, Command, NoConcreteFunc -import hugr._ops as ops -from hugr.serialization import SerialHugr -import hugr._tys as tys -import hugr._val as val -from hugr._function import Module +from hugr.node_port import Node, _SubPort + +from hugr.hugr import Hugr +from hugr.dfg import Dfg, _ancestral_sibling +from hugr.ops import NoConcreteFunc +import hugr.ops as ops +import hugr.tys as tys +import hugr.val as val +from hugr.function import Module import pytest -import json - -def int_t(width: int) -> tys.Opaque: - return tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=width)], - bound=tys.TypeBound.Eq, - ) - - -INT_T = int_t(5) - - -@dataclass -class IntVal(val.ExtensionValue): - v: int - - def to_value(self) -> val.Extension: - return val.Extension("int", INT_T, self.v) - - -@dataclass -class LogicOps(Custom): - extension: tys.ExtensionId = "logic" - - -# TODO get from YAML -@dataclass -class NotDef(LogicOps): - num_out: int | None = 1 - op_name: str = "Not" - signature: tys.FunctionType = tys.FunctionType.endo([tys.Bool]) - - def __call__(self, a: Wire) -> Command: - return super().__call__(a) - - -Not = NotDef() - - -@dataclass -class QuantumOps(Custom): - extension: tys.ExtensionId = "tket2.quantum" - - -@dataclass -class OneQbGate(QuantumOps): - op_name: str - num_out: int | None = 1 - signature: tys.FunctionType = tys.FunctionType.endo([tys.Qubit]) - - def __call__(self, q: Wire) -> Command: - return super().__call__(q) - - -H = OneQbGate("H") - - -@dataclass -class MeasureDef(QuantumOps): - op_name: str = "Measure" - num_out: int | None = 2 - signature: tys.FunctionType = tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]) - - def __call__(self, q: Wire) -> Command: - return super().__call__(q) - - -Measure = MeasureDef() - - -@dataclass -class IntOps(Custom): - extension: tys.ExtensionId = "arithmetic.int" - - -ARG_5 = tys.BoundedNatArg(n=5) - - -@dataclass -class DivModDef(IntOps): - num_out: int | None = 2 - extension: tys.ExtensionId = "arithmetic.int" - op_name: str = "idivmod_u" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) - ) - args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - - -DivMod = DivModDef() - - -def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): - workspace_dir = pathlib.Path(__file__).parent.parent.parent - # use the HUGR_BIN environment variable if set, otherwise use the debug build - bin_loc = os.environ.get("HUGR_BIN", str(workspace_dir / "target/debug/hugr")) - cmd = [bin_loc, "-"] - - if mermaid: - cmd.append("--mermaid") - serial = h.to_serial().to_json() - subprocess.run(cmd, check=True, input=serial.encode()) - - if roundtrip: - h2 = Hugr.from_serial(SerialHugr.load_json(json.loads(serial))) - assert serial == h2.to_serial().to_json() +from .conftest import Not, INT_T, IntVal, validate, DivMod def test_stable_indices(): @@ -168,7 +56,7 @@ def simple_id() -> Dfg: def test_simple_id(): - _validate(simple_id().hugr) + validate(simple_id().hugr) def test_multiport(): @@ -191,7 +79,7 @@ def test_multiport(): ] assert list(h.hugr.linked_ports(ou_n.inp(0))) == [in_n.out(0)] - _validate(h.hugr) + validate(h.hugr) def test_add_op(): @@ -200,7 +88,7 @@ def test_add_op(): nt = h.add_op(Not, a) h.set_outputs(nt) - _validate(h.hugr) + validate(h.hugr) def test_tuple(): @@ -211,7 +99,7 @@ def test_tuple(): a, b = h.add(ops.UnpackTuple(t)) h.set_outputs(a, b) - _validate(h.hugr) + validate(h.hugr) h1 = Dfg(*row) a, b = h1.inputs() @@ -227,7 +115,7 @@ def test_multi_out(): a, b = h.inputs() a, b = h.add(DivMod(a, b)) h.set_outputs(a, b) - _validate(h.hugr) + validate(h.hugr) def test_insert(): @@ -254,7 +142,7 @@ def test_insert_nested(): nested = h.insert_nested(h1, a) h.set_outputs(nested) assert len(h.hugr.children(nested)) == 3 - _validate(h.hugr) + validate(h.hugr) def test_build_nested(): @@ -271,7 +159,7 @@ def _nested_nop(dfg: Dfg): assert len(h.hugr.children(nested)) == 3 h.set_outputs(nested) - _validate(h.hugr) + validate(h.hugr) def test_build_inter_graph(): @@ -284,7 +172,7 @@ def test_build_inter_graph(): h.set_outputs(nested, b) - _validate(h.hugr) + validate(h.hugr) assert _SubPort(h.input_node.out(-1)) in h.hugr._links assert h.hugr.num_outgoing(h.input_node) == 2 # doesn't count state order @@ -315,7 +203,7 @@ def test_vals(val: val.Value): d = Dfg() d.set_outputs(d.load(val)) - _validate(d.hugr) + validate(d.hugr) @pytest.mark.parametrize("direct_call", [True, False]) @@ -348,7 +236,7 @@ def test_poly_function(direct_call: bool) -> None: f_main.set_outputs(call) - _validate(mod.hugr, True) + validate(mod.hugr, True) @pytest.mark.parametrize("direct_call", [True, False]) @@ -367,7 +255,7 @@ def test_mono_function(direct_call: bool) -> None: call = f_main.add(ops.CallIndirect(load, q)) f_main.set_outputs(call) - _validate(mod.hugr) + validate(mod.hugr) def test_higher_order() -> None: @@ -380,7 +268,7 @@ def test_higher_order() -> None: call = d.add(ops.CallIndirect(f_val, q))[0] d.set_outputs(call) - _validate(d.hugr) + validate(d.hugr) def test_lift() -> None: @@ -388,7 +276,7 @@ def test_lift() -> None: (q,) = d.inputs() lift = d.add(ops.Lift("X")(q)) d.set_outputs(lift) - _validate(d.hugr) + validate(d.hugr) def test_alias() -> None: @@ -396,4 +284,4 @@ def test_alias() -> None: _dfn = mod.add_alias_defn("my_int", INT_T) _dcl = mod.add_alias_decl("my_bool", tys.TypeBound.Eq) - _validate(mod.hugr) + validate(mod.hugr)