diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index 1cdffd8085..b15341c73c 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,9 +17,9 @@ from typing_extensions import Self from hugr.serialization.serial_hugr import SerialHugr -from hugr.serialization.ops import BaseOp, OpType as SerialOp -import hugr.serialization.ops as sops -from hugr.serialization.tys import Type +from hugr.serialization.ops import OpType as SerialOp +from hugr.serialization.tys import Type, FunctionType +from hugr._ops import Op, Input, Output, DFG from hugr.utils import BiMap @@ -43,6 +43,13 @@ class Wire(Protocol): def out_port(self) -> OutPort: ... +class Command(Protocol): + def op(self) -> Op: ... + def incoming(self) -> Iterable[Wire]: ... + def num_out(self) -> int | None: + return None + + @dataclass(frozen=True, eq=True, order=True) class OutPort(_Port, Wire): direction: ClassVar[Direction] = Direction.OUTGOING @@ -100,35 +107,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort: return self.out(offset) -class Op(Protocol): - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ... - - @classmethod - def from_serial(cls, serial: SerialOp) -> Self: ... - - -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class DummyOp(Op, Generic[T]): - _serial_op: T - - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=self._serial_op.model_copy()) # type: ignore - - @classmethod - def from_serial(cls, serial: SerialOp) -> DummyOp: - return DummyOp(serial.root) - - -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass() class NodeData: op: Op @@ -138,10 +116,9 @@ class NodeData: # TODO children field? def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - o = self.op.to_serial(node, hugr) - o.root.parent = self.parent.idx if self.parent else node.idx + o = self.op.to_serial(node, self.parent if self.parent else node, hugr) - return o + return SerialOp(root=o) # type: ignore P = TypeVar("P", InPort, OutPort) @@ -371,7 +348,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr: hugr.root = Node(idx) parent = None serial_node.root.parent = -1 - hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent)) + hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent)) for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: if src_offset is None or dst_offset is None: @@ -395,35 +372,25 @@ def __init__( ) -> None: input_types = list(input_types) output_types = list(output_types) - root_op = DummyOp(sops.DFG(parent=-1)) - root_op._serial_op.signature.input = input_types - root_op._serial_op.signature.output = output_types + root_op = DFG(FunctionType(input=input_types, output=output_types)) self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - DummyOp(sops.Input(parent=0, types=input_types)), - self.root, - len(input_types), - ) - self.output_node = self.hugr.add_node( - DummyOp(sops.Output(parent=0, types=output_types)), self.root + Input(input_types), self.root, len(input_types) ) + self.output_node = self.hugr.add_node(Output(output_types), self.root) @classmethod def endo(cls, types: Sequence[Type]) -> Dfg: return Dfg(types, types) - def _input_op(self) -> DummyOp[sops.Input]: + def _input_op(self) -> Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, DummyOp) - assert isinstance(dop._serial_op, sops.Input) + assert isinstance(dop, Input) return dop def inputs(self) -> list[OutPort]: - return [ - self.input_node.out(i) - for i in range(len(self._input_op()._serial_op.types)) - ] + return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py new file mode 100644 index 0000000000..05956024c3 --- /dev/null +++ b/hugr-py/src/hugr/_ops.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Generic, Protocol, TypeVar, TYPE_CHECKING +from hugr.serialization.ops import BaseOp +import hugr.serialization.ops as sops +import hugr.serialization.tys as tys + +if TYPE_CHECKING: + from hugr._hugr import Hugr, Node + + +class Op(Protocol): + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + + +T = TypeVar("T", bound=BaseOp) + + +@dataclass() +class SerWrap(Op, Generic[T]): + # catch all for serial ops that don't have a corresponding Op class + _serial_op: T + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: + root = self._serial_op.model_copy() + root.parent = parent.idx + return root + + +@dataclass() +class Input(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: + return sops.Input(parent=parent.idx, types=self.types) + + +@dataclass() +class Output(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: + return sops.Output(parent=parent.idx, types=self.types) + + +@dataclass() +class Custom(Op): + extension: tys.ExtensionId + op_name: str + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + description: str = "" + args: list[tys.TypeArg] = field(default_factory=list) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: + return sops.CustomOp( + parent=parent.idx, + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + description=self.description, + args=self.args, + ) + + +@dataclass() +class MakeTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: + return sops.MakeTuple( + parent=parent.idx, + tys=self.types, + ) + + +@dataclass() +class UnpackTuple(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: + return sops.UnpackTuple( + parent=parent.idx, + tys=self.types, + ) + + +@dataclass() +class DFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: + return sops.DFG( + parent=parent.idx, + signature=self.signature, + ) diff --git a/hugr-py/src/hugr/_types.py b/hugr-py/src/hugr/_types.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py deleted file mode 100644 index c645d06641..0000000000 --- a/hugr-py/src/hugr/ops.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Generic, TypeVar, TYPE_CHECKING -from hugr.serialization.ops import BaseOp, OpType as SerialOp -import hugr.serialization.ops as sops -import hugr.serialization.tys as tys - -if TYPE_CHECKING: - from hugr._hugr import Hugr, Node - - -class Op(ABC): - @abstractmethod - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: ... - - @classmethod - def from_serial(cls, serial: SerialOp) -> Op: - match serial.root: - case sops.Input(types=types): - return Input(types=types) - case sops.Output(types=types): - return Output(types=types) - case sops.CustomOp( - extension=extension, - op_name=op_name, - signature=signature, - description=description, - args=args, - ): - return Custom( - extension=extension, - op_name=op_name, - signature=signature, - description=description, - args=args, - ) - case sops.MakeTuple(tys=types): - return MakeTuple(types=types) - case sops.UnpackTuple(tys=types): - return UnpackTuple(types=types) - case sops.DFG(signature=signature): - return DFG(signature=signature) - return SerWrap(serial.root) - - -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class SerWrap(Op, Generic[T]): - # catch all for serial ops that don't have a corresponding Op class - _serial_op: T - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - root = self._serial_op.model_copy() - root.parent = parent.idx - return SerialOp(root=root) # type: ignore - - -@dataclass() -class Input(Op): - types: list[tys.Type] - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=sops.Input(parent=parent.idx, types=self.types)) - - -@dataclass() -class Output(Op): - types: list[tys.Type] - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=sops.Output(parent=parent.idx, types=self.types)) - - -@dataclass() -class Custom(Op): - extension: tys.ExtensionId - op_name: str - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) - description: str = "" - args: list[tys.TypeArg] = field(default_factory=list) - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp( - root=sops.CustomOp( - parent=parent.idx, - extension=self.extension, - op_name=self.op_name, - signature=self.signature, - description=self.description, - args=self.args, - ) - ) - - -@dataclass() -class MakeTuple(Op): - types: list[tys.Type] - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp( - root=sops.MakeTuple( - parent=parent.idx, - tys=self.types, - ) - ) - - -@dataclass() -class UnpackTuple(Op): - types: list[tys.Type] - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp( - root=sops.UnpackTuple( - parent=parent.idx, - tys=self.types, - ) - ) - - -@dataclass() -class DFG(Op): - signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) - - def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> SerialOp: - return SerialOp( - root=sops.DFG( - parent=parent.idx, - signature=self.signature, - ) - ) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index af127bfbde..4621c875b3 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect import sys from abc import ABC @@ -41,6 +42,10 @@ def display_name(self) -> str: """Name of the op for visualisation""" return self.__class__.__name__ + def deserialize(self) -> ops.Op: + """Deserializes the model into the corresponding Op.""" + return ops.SerWrap(self) + # ---------------------------------------------------------- # --------------- Module level operations ------------------ @@ -209,6 +214,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(in_types) == 0 self.types = list(out_types) + def deserialize(self) -> ops.Input: + return ops.Input(types=self.types) + class Output(DataflowOp): """An output node. The inputs are the outputs of the function.""" @@ -220,6 +228,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(out_types) == 0 self.types = list(in_types) + def deserialize(self) -> ops.Output: + return ops.Output(types=self.types) + class Call(DataflowOp): """ @@ -292,6 +303,9 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> ops.DFG: + return ops.DFG(self.signature) + # ------------------------------------------------ # --------------- ControlFlowOp ------------------ @@ -388,6 +402,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def display_name(self) -> str: return self.op_name + def deserialize(self) -> ops.Custom: + return ops.Custom( + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + args=self.args, + ) + model_config = ConfigDict( # Needed to avoid random '\n's in the pydantic description json_schema_extra={ @@ -424,6 +446,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) + def deserialize(self) -> ops.MakeTuple: + return ops.MakeTuple(self.tys) + class UnpackTuple(DataflowOp): """An operation that packs all its inputs into a tuple.""" @@ -434,6 +459,9 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) + def deserialize(self) -> ops.UnpackTuple: + return ops.UnpackTuple(self.tys) + class Tag(DataflowOp): """An operation that creates a tagged sum value from one of its variants.""" @@ -529,3 +557,6 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): ) tys_model_rebuild(dict(classes)) + +# +import hugr._ops as ops # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 522da06f6c..3f1971083d 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -3,10 +3,11 @@ import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op +from hugr._hugr import Dfg, Hugr, Node, Command, Wire +from hugr._ops import Op, Custom +import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys -import hugr.serialization.ops as sops import pytest import json @@ -22,14 +23,11 @@ ) ) -NOT_OP = DummyOp( - # TODO get from YAML - sops.CustomOp( - parent=-1, - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), - ) +# TODO get from YAML +NOT_OP = Custom( + extension="logic", + op_name="Not", + signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), ) @@ -59,14 +57,11 @@ def num_out(self) -> int | None: return 2 def op(self) -> Op: - return DummyOp( - sops.CustomOp( - parent=-1, - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) + return Custom( + extension="arithmetic.int", + op_name="idivmod_u", + signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), + args=[ARG_5, ARG_5], ) @@ -82,7 +77,7 @@ def num_out(self) -> int | None: return 1 def op(self) -> Op: - return DummyOp(sops.MakeTuple(parent=-1, tys=self.types)) + return ops.MakeTuple(self.types) @dataclass @@ -97,7 +92,7 @@ def num_out(self) -> int | None: return len(self.types) def op(self) -> Op: - return DummyOp(sops.UnpackTuple(parent=-1, tys=self.types)) + return ops.UnpackTuple(self.types) def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -117,7 +112,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(DummyOp(sops.DFG(parent=-1))) + h = Hugr(ops.DFG()) nodes = [h.add_node(NOT_OP) for _ in range(3)] assert len(h) == 4 @@ -201,8 +196,8 @@ def test_tuple(): h1 = Dfg.endo(row) a, b = h1.inputs() - mt = h1.add_op(DummyOp(sops.MakeTuple(parent=-1, tys=row)), a, b) - a, b = h1.add_op(DummyOp(sops.UnpackTuple(parent=-1, tys=row)), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple(row), a, b) + a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() @@ -224,7 +219,7 @@ def test_insert(): assert len(h1.hugr) == 4 - new_h = Hugr(DummyOp(sops.DFG(parent=-1))) + new_h = Hugr(ops.DFG()) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)}