From 1ffd7458647c4ee170c9f2c706ed6b6b55b8da04 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 30 May 2024 13:41:01 +0100 Subject: [PATCH 1/5] feat: builder ops separate from serialised ops no more parent=-1 --- hugr-py/src/hugr/_hugr.py | 71 ++++++-------------- hugr-py/src/hugr/_ops.py | 96 +++++++++++++++++++++++++++ hugr-py/src/hugr/_types.py | 0 hugr-py/src/hugr/serialization/ops.py | 31 +++++++++ hugr-py/tests/test_hugr_build.py | 43 ++++++------ 5 files changed, 165 insertions(+), 76 deletions(-) create mode 100644 hugr-py/src/hugr/_ops.py create mode 100644 hugr-py/src/hugr/_types.py diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index d7552113f..ad79eb201 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,9 +17,9 @@ from typing_extensions import Self from hugr.serialization.serial_hugr import SerialHugr -from hugr.serialization.ops import BaseOp, OpType as SerialOp -import hugr.serialization.ops as sops -from hugr.serialization.tys import Type +from hugr.serialization.ops import OpType as SerialOp +from hugr.serialization.tys import Type, FunctionType +from hugr._ops import Op, Input, Output, DFG from hugr.utils import BiMap @@ -44,6 +44,13 @@ class Wire(Protocol): def out_port(self) -> OutPort: ... +class Command(Protocol): + def op(self) -> Op: ... + def incoming(self) -> Iterable[Wire]: ... + def num_out(self) -> int | None: + return None + + @dataclass(frozen=True, eq=True, order=True) class OutPort(_Port, Wire): direction: ClassVar[Direction] = Direction.OUTGOING @@ -101,35 +108,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort: return self.out(offset) -class Op(Protocol): - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ... - - @classmethod - def from_serial(cls, serial: SerialOp) -> Self: ... - - -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class DummyOp(Op, Generic[T]): - _serial_op: T - - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=self._serial_op.model_copy()) # type: ignore - - @classmethod - def from_serial(cls, serial: SerialOp) -> DummyOp: - return DummyOp(serial.root) - - -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass() class NodeData: op: Op @@ -139,10 +117,9 @@ class NodeData: # TODO children field? def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - o = self.op.to_serial(node, hugr) - o.root.parent = self.parent.idx if self.parent else node.idx + o = self.op.to_serial(node, self.parent if self.parent else node, hugr) - return o + return SerialOp(root=o) # type: ignore P = TypeVar("P", InPort, OutPort) @@ -372,7 +349,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr: hugr.root = Node(idx) parent = None serial_node.root.parent = -1 - hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent)) + hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent)) for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: if src_offset is None or dst_offset is None: @@ -396,35 +373,25 @@ def __init__( ) -> None: input_types = list(input_types) output_types = list(output_types) - root_op = DummyOp(sops.DFG(parent=-1)) - root_op._serial_op.signature.input = input_types - root_op._serial_op.signature.output = output_types + root_op = DFG(FunctionType(input=input_types, output=output_types)) self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - DummyOp(sops.Input(parent=0, types=input_types)), - self.root, - len(input_types), - ) - self.output_node = self.hugr.add_node( - DummyOp(sops.Output(parent=0, types=output_types)), self.root + Input(input_types), self.root, len(input_types) ) + self.output_node = self.hugr.add_node(Output(output_types), self.root) @classmethod def endo(cls, types: Sequence[Type]) -> Dfg: return Dfg(types, types) - def _input_op(self) -> DummyOp[sops.Input]: + def _input_op(self) -> Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, DummyOp) - assert isinstance(dop._serial_op, sops.Input) + assert isinstance(dop, Input) return dop def inputs(self) -> list[OutPort]: - return [ - self.input_node.out(i) - for i in range(len(self._input_op()._serial_op.types)) - ] + return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py new file mode 100644 index 000000000..d6d8dc691 --- /dev/null +++ b/hugr-py/src/hugr/_ops.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Generic, Protocol, TypeVar, TYPE_CHECKING +from hugr.serialization.ops import BaseOp +import hugr.serialization.ops as sops +import hugr.serialization.tys as tys + +if TYPE_CHECKING: + from hugr._hugr import Hugr, Node + + +class Op(Protocol): + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + + +T = TypeVar("T", bound=BaseOp) + + +@dataclass() +class SerWrap(Op, Generic[T]): + # catch all for serial ops that don't have a corresponding Op class + _serial_op: T + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: + root = self._serial_op.model_copy() + root.parent = parent.idx + return root + + +@dataclass() +class Input(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: + return sops.Input(parent=parent.idx, types=self.types) + + +@dataclass() +class Output(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: + return sops.Output(parent=parent.idx, types=self.types) + + +@dataclass() +class Custom(Op): + extension: tys.ExtensionId + op_name: str + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + description: str = "" + args: list[tys.TypeArg] = field(default_factory=list) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: + return sops.CustomOp( + parent=parent.idx, + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + description=self.description, + args=self.args, + ) + + +@dataclass() +class MakeTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: + return sops.MakeTuple( + parent=parent.idx, + tys=self.types, + ) + + +@dataclass() +class UnpackTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: + return sops.UnpackTuple( + parent=parent.idx, + tys=self.types, + ) + + +@dataclass() +class DFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: + return sops.DFG( + parent=parent.idx, + signature=self.signature, + ) diff --git a/hugr-py/src/hugr/_types.py b/hugr-py/src/hugr/_types.py new file mode 100644 index 000000000..e69de29bb diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index af127bfbd..4621c875b 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect import sys from abc import ABC @@ -41,6 +42,10 @@ def display_name(self) -> str: """Name of the op for visualisation""" return self.__class__.__name__ + def deserialize(self) -> ops.Op: + """Deserializes the model into the corresponding Op.""" + return ops.SerWrap(self) + # ---------------------------------------------------------- # --------------- Module level operations ------------------ @@ -209,6 +214,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(in_types) == 0 self.types = list(out_types) + def deserialize(self) -> ops.Input: + return ops.Input(types=self.types) + class Output(DataflowOp): """An output node. The inputs are the outputs of the function.""" @@ -220,6 +228,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(out_types) == 0 self.types = list(in_types) + def deserialize(self) -> ops.Output: + return ops.Output(types=self.types) + class Call(DataflowOp): """ @@ -292,6 +303,9 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> ops.DFG: + return ops.DFG(self.signature) + # ------------------------------------------------ # --------------- ControlFlowOp ------------------ @@ -388,6 +402,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def display_name(self) -> str: return self.op_name + def deserialize(self) -> ops.Custom: + return ops.Custom( + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + args=self.args, + ) + model_config = ConfigDict( # Needed to avoid random '\n's in the pydantic description json_schema_extra={ @@ -424,6 +446,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) + def deserialize(self) -> ops.MakeTuple: + return ops.MakeTuple(self.tys) + class UnpackTuple(DataflowOp): """An operation that packs all its inputs into a tuple.""" @@ -434,6 +459,9 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) + def deserialize(self) -> ops.UnpackTuple: + return ops.UnpackTuple(self.tys) + class Tag(DataflowOp): """An operation that creates a tagged sum value from one of its variants.""" @@ -529,3 +557,6 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): ) tys_model_rebuild(dict(classes)) + +# +import hugr._ops as ops # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 522da06f6..3f1971083 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,10 +3,11 @@ import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op +from hugr._hugr import Dfg, Hugr, Node, Command, Wire +from hugr._ops import Op, Custom +import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys -import hugr.serialization.ops as sops import pytest import json @@ -22,14 +23,11 @@ ) ) -NOT_OP = DummyOp( - # TODO get from YAML - sops.CustomOp( - parent=-1, - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), - ) +# TODO get from YAML +NOT_OP = Custom( + extension="logic", + op_name="Not", + signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), ) @@ -59,14 +57,11 @@ def num_out(self) -> int | None: return 2 def op(self) -> Op: - return DummyOp( - sops.CustomOp( - parent=-1, - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) + return Custom( + extension="arithmetic.int", + op_name="idivmod_u", + signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), + args=[ARG_5, ARG_5], ) @@ -82,7 +77,7 @@ def num_out(self) -> int | None: return 1 def op(self) -> Op: - return DummyOp(sops.MakeTuple(parent=-1, tys=self.types)) + return ops.MakeTuple(self.types) @dataclass @@ -97,7 +92,7 @@ def num_out(self) -> int | None: return len(self.types) def op(self) -> Op: - return DummyOp(sops.UnpackTuple(parent=-1, tys=self.types)) + return ops.UnpackTuple(self.types) def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -117,7 +112,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(DummyOp(sops.DFG(parent=-1))) + h = Hugr(ops.DFG()) nodes = [h.add_node(NOT_OP) for _ in range(3)] assert len(h) == 4 @@ -201,8 +196,8 @@ def test_tuple(): h1 = Dfg.endo(row) a, b = h1.inputs() - mt = h1.add_op(DummyOp(sops.MakeTuple(parent=-1, tys=row)), a, b) - a, b = h1.add_op(DummyOp(sops.UnpackTuple(parent=-1, tys=row)), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple(row), a, b) + a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() @@ -224,7 +219,7 @@ def test_insert(): assert len(h1.hugr) == 4 - new_h = Hugr(DummyOp(sops.DFG(parent=-1))) + new_h = Hugr(ops.DFG()) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)} From 3e699fce265834ae85de1cf9e892a9010be52109 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 31 May 2024 17:20:02 +0100 Subject: [PATCH 2/5] refactor: commands -> call on ops --- hugr-py/src/hugr/_hugr.py | 13 ++--- hugr-py/src/hugr/_ops.py | 31 ++++++++++- hugr-py/tests/test_hugr_build.py | 93 +++++++++++--------------------- 3 files changed, 63 insertions(+), 74 deletions(-) diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index ad79eb201..ac8766aea 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -19,7 +19,7 @@ from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.tys import Type, FunctionType -from hugr._ops import Op, Input, Output, DFG +from hugr._ops import Op, Input, Output, DFG, Command from hugr.utils import BiMap @@ -44,13 +44,6 @@ class Wire(Protocol): def out_port(self) -> OutPort: ... -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass(frozen=True, eq=True, order=True) class OutPort(_Port, Wire): direction: ClassVar[Direction] = Direction.OUTGOING @@ -119,7 +112,7 @@ class NodeData: def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: o = self.op.to_serial(node, self.parent if self.parent else node, hugr) - return SerialOp(root=o) # type: ignore + return SerialOp(root=o) # type: ignore[arg-type] P = TypeVar("P", InPort, OutPort) @@ -399,7 +392,7 @@ def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: return new_n def add(self, com: Command) -> Node: - return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out()) + return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(dfg.hugr, self.root) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index d6d8dc691..37b934bc0 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -7,12 +7,25 @@ import hugr.serialization.tys as tys if TYPE_CHECKING: - from hugr._hugr import Hugr, Node + from hugr._hugr import Hugr, Node, Wire class Op(Protocol): + @property + def num_out(self) -> int | None: + return None + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + def __call__(self, *args) -> Command: + return Command(self, list(args)) + + +@dataclass(frozen=True) +class Command: + op: Op + incoming: list[Wire] + T = TypeVar("T", bound=BaseOp) @@ -35,6 +48,9 @@ class Input(Op): def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: return sops.Input(parent=parent.idx, types=self.types) + def __call__(self) -> Command: + return super().__call__() + @dataclass() class Output(Op): @@ -46,10 +62,10 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: @dataclass() class Custom(Op): - extension: tys.ExtensionId op_name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) description: str = "" + extension: tys.ExtensionId = "" args: list[tys.TypeArg] = field(default_factory=list) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: @@ -66,6 +82,7 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: @dataclass() class MakeTuple(Op): types: list[tys.Type] + num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( @@ -73,17 +90,27 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: tys=self.types, ) + def __call__(self, *elements: Wire) -> Command: + return super().__call__(*elements) + @dataclass() class UnpackTuple(Op): types: list[tys.Type] + @property + def num_out(self) -> int | None: + return len(self.types) + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: return sops.UnpackTuple( parent=parent.idx, tys=self.types, ) + def __call__(self, tuple_: Wire) -> Command: + return super().__call__(tuple_) + @dataclass() class DFG(Op): diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 3f1971083..5c78a145b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,10 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, Node, Command, Wire -from hugr._ops import Op, Custom +from hugr._hugr import Dfg, Hugr, Node, Wire +from hugr._ops import Custom, Command import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys @@ -23,76 +23,45 @@ ) ) -# TODO get from YAML -NOT_OP = Custom( - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), -) - @dataclass -class Not(Command): - a: Wire - - def incoming(self) -> list[Wire]: - return [self.a] - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return NOT_OP +class LogicOps(Custom): + extension: stys.ExtensionId = "logic" +# TODO get from YAML @dataclass -class DivMod(Command): - a: Wire - b: Wire +class NotDef(LogicOps): + num_out: int | None = 1 + op_name: str = "Not" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[BOOL_T], output=[BOOL_T]) + ) - def incoming(self) -> list[Wire]: - return [self.a, self.b] + def __call__(self, a: Wire) -> Command: + return super().__call__(a) - def num_out(self) -> int | None: - return 2 - def op(self) -> Op: - return Custom( - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) +Not = NotDef() @dataclass -class MakeTuple(Command): - types: list[stys.Type] - wires: list[Wire] - - def incoming(self) -> list[Wire]: - return self.wires - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return ops.MakeTuple(self.types) +class IntOps(Custom): + extension: stys.ExtensionId = "arithmetic.int" @dataclass -class UnpackTuple(Command): - types: list[stys.Type] - wire: Wire - - def incoming(self) -> list[Wire]: - return [self.wire] +class DivModDef(IntOps): + num_out: int | None = 2 + extension: stys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[stys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - def num_out(self) -> int | None: - return len(self.types) - def op(self) -> Op: - return ops.UnpackTuple(self.types) +DivMod = DivModDef() def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -114,7 +83,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): h = Hugr(ops.DFG()) - nodes = [h.add_node(NOT_OP) for _ in range(3)] + nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 h.add_link(nodes[0].out(0), nodes[1].inp(0)) @@ -137,7 +106,7 @@ def test_stable_indices(): with pytest.raises(KeyError): _ = h[Node(46)] - new_n = h.add_node(NOT_OP) + new_n = h.add_node(Not) assert new_n == nodes[1] assert len(h) == 4 @@ -178,7 +147,7 @@ def test_multiport(): def test_add_op(): h = Dfg.endo([BOOL_T]) (a,) = h.inputs() - nt = h.add_op(NOT_OP, a) + nt = h.add_op(Not, a) h.set_outputs(nt) _validate(h.hugr) @@ -188,8 +157,8 @@ def test_tuple(): row = [BOOL_T, QB_T] h = Dfg.endo(row) a, b = h.inputs() - t = h.add(MakeTuple(row, [a, b])) - a, b = h.add(UnpackTuple(row, t)) + t = h.add(ops.MakeTuple(row)(a, b)) + a, b = h.add(ops.UnpackTuple(row)(t)) h.set_outputs(a, b) _validate(h.hugr) From d180047da98452538a748fb2bb037e02c62d5bd4 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 4 Jun 2024 16:35:23 +0100 Subject: [PATCH 3/5] fixup! feat: builder ops separate from serialised ops --- hugr-py/src/hugr/_types.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 hugr-py/src/hugr/_types.py diff --git a/hugr-py/src/hugr/_types.py b/hugr-py/src/hugr/_types.py deleted file mode 100644 index e69de29bb..000000000 From ebfd51bd03161a04a08d0eb1540cf2bb28a93566 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 4 Jun 2024 17:13:24 +0100 Subject: [PATCH 4/5] impl `num_out` for more ops --- hugr-py/src/hugr/_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 37b934bc0..693c8ef31 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -45,6 +45,10 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: class Input(Op): types: list[tys.Type] + @property + def num_out(self) -> int | None: + return len(self.types) + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: return sops.Input(parent=parent.idx, types=self.types) @@ -68,6 +72,10 @@ class Custom(Op): extension: tys.ExtensionId = "" args: list[tys.TypeArg] = field(default_factory=list) + @property + def num_out(self) -> int | None: + return len(self.signature.output) + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: return sops.CustomOp( parent=parent.idx, @@ -116,6 +124,10 @@ def __call__(self, tuple_: Wire) -> Command: class DFG(Op): signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + @property + def num_out(self) -> int | None: + return len(self.signature.output) + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: return sops.DFG( parent=parent.idx, From ac56552479a2cfc98e65dcef790360d2b7b865be Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 4 Jun 2024 17:15:32 +0100 Subject: [PATCH 5/5] refactor: use _ops in circular import --- hugr-py/src/hugr/serialization/ops.py | 31 +++++++++++++-------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index 4621c875b..ed3b74086 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -42,9 +42,9 @@ def display_name(self) -> str: """Name of the op for visualisation""" return self.__class__.__name__ - def deserialize(self) -> ops.Op: + def deserialize(self) -> _ops.Op: """Deserializes the model into the corresponding Op.""" - return ops.SerWrap(self) + return _ops.SerWrap(self) # ---------------------------------------------------------- @@ -214,8 +214,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(in_types) == 0 self.types = list(out_types) - def deserialize(self) -> ops.Input: - return ops.Input(types=self.types) + def deserialize(self) -> _ops.Input: + return _ops.Input(types=self.types) class Output(DataflowOp): @@ -228,8 +228,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(out_types) == 0 self.types = list(in_types) - def deserialize(self) -> ops.Output: - return ops.Output(types=self.types) + def deserialize(self) -> _ops.Output: + return _ops.Output(types=self.types) class Call(DataflowOp): @@ -303,8 +303,8 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) - def deserialize(self) -> ops.DFG: - return ops.DFG(self.signature) + def deserialize(self) -> _ops.DFG: + return _ops.DFG(self.signature) # ------------------------------------------------ @@ -402,8 +402,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def display_name(self) -> str: return self.op_name - def deserialize(self) -> ops.Custom: - return ops.Custom( + def deserialize(self) -> _ops.Custom: + return _ops.Custom( extension=self.extension, op_name=self.op_name, signature=self.signature, @@ -446,8 +446,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) - def deserialize(self) -> ops.MakeTuple: - return ops.MakeTuple(self.tys) + def deserialize(self) -> _ops.MakeTuple: + return _ops.MakeTuple(self.tys) class UnpackTuple(DataflowOp): @@ -459,8 +459,8 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) - def deserialize(self) -> ops.UnpackTuple: - return ops.UnpackTuple(self.tys) + def deserialize(self) -> _ops.UnpackTuple: + return _ops.UnpackTuple(self.tys) class Tag(DataflowOp): @@ -558,5 +558,4 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): tys_model_rebuild(dict(classes)) -# -import hugr._ops as ops # noqa: E402 # needed to avoid circular imports +from hugr import _ops # noqa: E402 # needed to avoid circular imports