diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index d7552113f..ac8766aea 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -17,9 +17,9 @@ from typing_extensions import Self from hugr.serialization.serial_hugr import SerialHugr -from hugr.serialization.ops import BaseOp, OpType as SerialOp -import hugr.serialization.ops as sops -from hugr.serialization.tys import Type +from hugr.serialization.ops import OpType as SerialOp +from hugr.serialization.tys import Type, FunctionType +from hugr._ops import Op, Input, Output, DFG, Command from hugr.utils import BiMap @@ -101,35 +101,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort: return self.out(offset) -class Op(Protocol): - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ... - - @classmethod - def from_serial(cls, serial: SerialOp) -> Self: ... - - -T = TypeVar("T", bound=BaseOp) - - -@dataclass() -class DummyOp(Op, Generic[T]): - _serial_op: T - - def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - return SerialOp(root=self._serial_op.model_copy()) # type: ignore - - @classmethod - def from_serial(cls, serial: SerialOp) -> DummyOp: - return DummyOp(serial.root) - - -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass() class NodeData: op: Op @@ -139,10 +110,9 @@ class NodeData: # TODO children field? def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: - o = self.op.to_serial(node, hugr) - o.root.parent = self.parent.idx if self.parent else node.idx + o = self.op.to_serial(node, self.parent if self.parent else node, hugr) - return o + return SerialOp(root=o) # type: ignore[arg-type] P = TypeVar("P", InPort, OutPort) @@ -372,7 +342,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr: hugr.root = Node(idx) parent = None serial_node.root.parent = -1 - hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent)) + hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent)) for (src_node, src_offset), (dst_node, dst_offset) in serial.edges: if src_offset is None or dst_offset is None: @@ -396,35 +366,25 @@ def __init__( ) -> None: input_types = list(input_types) output_types = list(output_types) - root_op = DummyOp(sops.DFG(parent=-1)) - root_op._serial_op.signature.input = input_types - root_op._serial_op.signature.output = output_types + root_op = DFG(FunctionType(input=input_types, output=output_types)) self.hugr = Hugr(root_op) self.root = self.hugr.root self.input_node = self.hugr.add_node( - DummyOp(sops.Input(parent=0, types=input_types)), - self.root, - len(input_types), - ) - self.output_node = self.hugr.add_node( - DummyOp(sops.Output(parent=0, types=output_types)), self.root + Input(input_types), self.root, len(input_types) ) + self.output_node = self.hugr.add_node(Output(output_types), self.root) @classmethod def endo(cls, types: Sequence[Type]) -> Dfg: return Dfg(types, types) - def _input_op(self) -> DummyOp[sops.Input]: + def _input_op(self) -> Input: dop = self.hugr[self.input_node].op - assert isinstance(dop, DummyOp) - assert isinstance(dop._serial_op, sops.Input) + assert isinstance(dop, Input) return dop def inputs(self) -> list[OutPort]: - return [ - self.input_node.out(i) - for i in range(len(self._input_op()._serial_op.types)) - ] + return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: new_n = self.hugr.add_node(op, self.root, num_outs=num_outs) @@ -432,7 +392,7 @@ def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: return new_n def add(self, com: Command) -> Node: - return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out()) + return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(dfg.hugr, self.root) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py new file mode 100644 index 000000000..693c8ef31 --- /dev/null +++ b/hugr-py/src/hugr/_ops.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Generic, Protocol, TypeVar, TYPE_CHECKING +from hugr.serialization.ops import BaseOp +import hugr.serialization.ops as sops +import hugr.serialization.tys as tys + +if TYPE_CHECKING: + from hugr._hugr import Hugr, Node, Wire + + +class Op(Protocol): + @property + def num_out(self) -> int | None: + return None + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + + def __call__(self, *args) -> Command: + return Command(self, list(args)) + + +@dataclass(frozen=True) +class Command: + op: Op + incoming: list[Wire] + + +T = TypeVar("T", bound=BaseOp) + + +@dataclass() +class SerWrap(Op, Generic[T]): + # catch all for serial ops that don't have a corresponding Op class + _serial_op: T + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: + root = self._serial_op.model_copy() + root.parent = parent.idx + return root + + +@dataclass() +class Input(Op): + types: list[tys.Type] + + @property + def num_out(self) -> int | None: + return len(self.types) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: + return sops.Input(parent=parent.idx, types=self.types) + + def __call__(self) -> Command: + return super().__call__() + + +@dataclass() +class Output(Op): + types: list[tys.Type] + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: + return sops.Output(parent=parent.idx, types=self.types) + + +@dataclass() +class Custom(Op): + op_name: str + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + description: str = "" + extension: tys.ExtensionId = "" + args: list[tys.TypeArg] = field(default_factory=list) + + @property + def num_out(self) -> int | None: + return len(self.signature.output) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: + return sops.CustomOp( + parent=parent.idx, + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + description=self.description, + args=self.args, + ) + + +@dataclass() +class MakeTuple(Op): + types: list[tys.Type] + num_out: int | None = 1 + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: + return sops.MakeTuple( + parent=parent.idx, + tys=self.types, + ) + + def __call__(self, *elements: Wire) -> Command: + return super().__call__(*elements) + + +@dataclass() +class UnpackTuple(Op): + types: list[tys.Type] + + @property + def num_out(self) -> int | None: + return len(self.types) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: + return sops.UnpackTuple( + parent=parent.idx, + tys=self.types, + ) + + def __call__(self, tuple_: Wire) -> Command: + return super().__call__(tuple_) + + +@dataclass() +class DFG(Op): + signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) + + @property + def num_out(self) -> int | None: + return len(self.signature.output) + + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: + return sops.DFG( + parent=parent.idx, + signature=self.signature, + ) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index af127bfbd..ed3b74086 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -1,3 +1,4 @@ +from __future__ import annotations import inspect import sys from abc import ABC @@ -41,6 +42,10 @@ def display_name(self) -> str: """Name of the op for visualisation""" return self.__class__.__name__ + def deserialize(self) -> _ops.Op: + """Deserializes the model into the corresponding Op.""" + return _ops.SerWrap(self) + # ---------------------------------------------------------- # --------------- Module level operations ------------------ @@ -209,6 +214,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(in_types) == 0 self.types = list(out_types) + def deserialize(self) -> _ops.Input: + return _ops.Input(types=self.types) + class Output(DataflowOp): """An output node. The inputs are the outputs of the function.""" @@ -220,6 +228,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert len(out_types) == 0 self.types = list(in_types) + def deserialize(self) -> _ops.Output: + return _ops.Output(types=self.types) + class Call(DataflowOp): """ @@ -292,6 +303,9 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([]) ) + def deserialize(self) -> _ops.DFG: + return _ops.DFG(self.signature) + # ------------------------------------------------ # --------------- ControlFlowOp ------------------ @@ -388,6 +402,14 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: def display_name(self) -> str: return self.op_name + def deserialize(self) -> _ops.Custom: + return _ops.Custom( + extension=self.extension, + op_name=self.op_name, + signature=self.signature, + args=self.args, + ) + model_config = ConfigDict( # Needed to avoid random '\n's in the pydantic description json_schema_extra={ @@ -424,6 +446,9 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: in_types = [] self.tys = list(in_types) + def deserialize(self) -> _ops.MakeTuple: + return _ops.MakeTuple(self.tys) + class UnpackTuple(DataflowOp): """An operation that packs all its inputs into a tuple.""" @@ -434,6 +459,9 @@ class UnpackTuple(DataflowOp): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) + def deserialize(self) -> _ops.UnpackTuple: + return _ops.UnpackTuple(self.tys) + class Tag(DataflowOp): """An operation that creates a tagged sum value from one of its variants.""" @@ -529,3 +557,5 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): ) tys_model_rebuild(dict(classes)) + +from hugr import _ops # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 522da06f6..5c78a145b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,12 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op +from hugr._hugr import Dfg, Hugr, Node, Wire +from hugr._ops import Custom, Command +import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys -import hugr.serialization.ops as sops import pytest import json @@ -22,82 +23,45 @@ ) ) -NOT_OP = DummyOp( - # TODO get from YAML - sops.CustomOp( - parent=-1, - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), - ) -) - @dataclass -class Not(Command): - a: Wire - - def incoming(self) -> list[Wire]: - return [self.a] - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return NOT_OP +class LogicOps(Custom): + extension: stys.ExtensionId = "logic" +# TODO get from YAML @dataclass -class DivMod(Command): - a: Wire - b: Wire +class NotDef(LogicOps): + num_out: int | None = 1 + op_name: str = "Not" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[BOOL_T], output=[BOOL_T]) + ) - def incoming(self) -> list[Wire]: - return [self.a, self.b] + def __call__(self, a: Wire) -> Command: + return super().__call__(a) - def num_out(self) -> int | None: - return 2 - def op(self) -> Op: - return DummyOp( - sops.CustomOp( - parent=-1, - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) - ) +Not = NotDef() @dataclass -class MakeTuple(Command): - types: list[stys.Type] - wires: list[Wire] - - def incoming(self) -> list[Wire]: - return self.wires - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return DummyOp(sops.MakeTuple(parent=-1, tys=self.types)) +class IntOps(Custom): + extension: stys.ExtensionId = "arithmetic.int" @dataclass -class UnpackTuple(Command): - types: list[stys.Type] - wire: Wire - - def incoming(self) -> list[Wire]: - return [self.wire] +class DivModDef(IntOps): + num_out: int | None = 2 + extension: stys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[stys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - def num_out(self) -> int | None: - return len(self.types) - def op(self) -> Op: - return DummyOp(sops.UnpackTuple(parent=-1, tys=self.types)) +DivMod = DivModDef() def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -117,9 +81,9 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): - h = Hugr(DummyOp(sops.DFG(parent=-1))) + h = Hugr(ops.DFG()) - nodes = [h.add_node(NOT_OP) for _ in range(3)] + nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 h.add_link(nodes[0].out(0), nodes[1].inp(0)) @@ -142,7 +106,7 @@ def test_stable_indices(): with pytest.raises(KeyError): _ = h[Node(46)] - new_n = h.add_node(NOT_OP) + new_n = h.add_node(Not) assert new_n == nodes[1] assert len(h) == 4 @@ -183,7 +147,7 @@ def test_multiport(): def test_add_op(): h = Dfg.endo([BOOL_T]) (a,) = h.inputs() - nt = h.add_op(NOT_OP, a) + nt = h.add_op(Not, a) h.set_outputs(nt) _validate(h.hugr) @@ -193,16 +157,16 @@ def test_tuple(): row = [BOOL_T, QB_T] h = Dfg.endo(row) a, b = h.inputs() - t = h.add(MakeTuple(row, [a, b])) - a, b = h.add(UnpackTuple(row, t)) + t = h.add(ops.MakeTuple(row)(a, b)) + a, b = h.add(ops.UnpackTuple(row)(t)) h.set_outputs(a, b) _validate(h.hugr) h1 = Dfg.endo(row) a, b = h1.inputs() - mt = h1.add_op(DummyOp(sops.MakeTuple(parent=-1, tys=row)), a, b) - a, b = h1.add_op(DummyOp(sops.UnpackTuple(parent=-1, tys=row)), mt)[0, 1] + mt = h1.add_op(ops.MakeTuple(row), a, b) + a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1] h1.set_outputs(a, b) assert h.hugr.to_serial() == h1.hugr.to_serial() @@ -224,7 +188,7 @@ def test_insert(): assert len(h1.hugr) == 4 - new_h = Hugr(DummyOp(sops.DFG(parent=-1))) + new_h = Hugr(ops.DFG()) mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root) assert mapping == {new_h.root: Node(4)}