diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index ad79eb201..a6420f6c1 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -19,7 +19,7 @@ from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.ops import OpType as SerialOp from hugr.serialization.tys import Type, FunctionType -from hugr._ops import Op, Input, Output, DFG +from hugr._ops import Op, Input, Output, DFG, Command from hugr.utils import BiMap @@ -44,13 +44,6 @@ class Wire(Protocol): def out_port(self) -> OutPort: ... -class Command(Protocol): - def op(self) -> Op: ... - def incoming(self) -> Iterable[Wire]: ... - def num_out(self) -> int | None: - return None - - @dataclass(frozen=True, eq=True, order=True) class OutPort(_Port, Wire): direction: ClassVar[Direction] = Direction.OUTGOING @@ -399,7 +392,7 @@ def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node: return new_n def add(self, com: Command) -> Node: - return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out()) + return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out) def insert_nested(self, dfg: Dfg, *args: Wire) -> Node: mapping = self.hugr.insert_hugr(dfg.hugr, self.root) diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index d6d8dc691..37b934bc0 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -7,12 +7,25 @@ import hugr.serialization.tys as tys if TYPE_CHECKING: - from hugr._hugr import Hugr, Node + from hugr._hugr import Hugr, Node, Wire class Op(Protocol): + @property + def num_out(self) -> int | None: + return None + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ... + def __call__(self, *args) -> Command: + return Command(self, list(args)) + + +@dataclass(frozen=True) +class Command: + op: Op + incoming: list[Wire] + T = TypeVar("T", bound=BaseOp) @@ -35,6 +48,9 @@ class Input(Op): def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: return sops.Input(parent=parent.idx, types=self.types) + def __call__(self) -> Command: + return super().__call__() + @dataclass() class Output(Op): @@ -46,10 +62,10 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: @dataclass() class Custom(Op): - extension: tys.ExtensionId op_name: str signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty) description: str = "" + extension: tys.ExtensionId = "" args: list[tys.TypeArg] = field(default_factory=list) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: @@ -66,6 +82,7 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: @dataclass() class MakeTuple(Op): types: list[tys.Type] + num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( @@ -73,17 +90,27 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: tys=self.types, ) + def __call__(self, *elements: Wire) -> Command: + return super().__call__(*elements) + @dataclass() class UnpackTuple(Op): types: list[tys.Type] + @property + def num_out(self) -> int | None: + return len(self.types) + def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: return sops.UnpackTuple( parent=parent.idx, tys=self.types, ) + def __call__(self, tuple_: Wire) -> Command: + return super().__call__(tuple_) + @dataclass() class DFG(Op): diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 3f1971083..5c78a145b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -1,10 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field import subprocess import os import pathlib -from hugr._hugr import Dfg, Hugr, Node, Command, Wire -from hugr._ops import Op, Custom +from hugr._hugr import Dfg, Hugr, Node, Wire +from hugr._ops import Custom, Command import hugr._ops as ops from hugr.serialization import SerialHugr import hugr.serialization.tys as stys @@ -23,76 +23,45 @@ ) ) -# TODO get from YAML -NOT_OP = Custom( - extension="logic", - op_name="Not", - signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]), -) - @dataclass -class Not(Command): - a: Wire - - def incoming(self) -> list[Wire]: - return [self.a] - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return NOT_OP +class LogicOps(Custom): + extension: stys.ExtensionId = "logic" +# TODO get from YAML @dataclass -class DivMod(Command): - a: Wire - b: Wire +class NotDef(LogicOps): + num_out: int | None = 1 + op_name: str = "Not" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[BOOL_T], output=[BOOL_T]) + ) - def incoming(self) -> list[Wire]: - return [self.a, self.b] + def __call__(self, a: Wire) -> Command: + return super().__call__(a) - def num_out(self) -> int | None: - return 2 - def op(self) -> Op: - return Custom( - extension="arithmetic.int", - op_name="idivmod_u", - signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2), - args=[ARG_5, ARG_5], - ) +Not = NotDef() @dataclass -class MakeTuple(Command): - types: list[stys.Type] - wires: list[Wire] - - def incoming(self) -> list[Wire]: - return self.wires - - def num_out(self) -> int | None: - return 1 - - def op(self) -> Op: - return ops.MakeTuple(self.types) +class IntOps(Custom): + extension: stys.ExtensionId = "arithmetic.int" @dataclass -class UnpackTuple(Command): - types: list[stys.Type] - wire: Wire - - def incoming(self) -> list[Wire]: - return [self.wire] +class DivModDef(IntOps): + num_out: int | None = 2 + extension: stys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: stys.FunctionType = field( + default_factory=lambda: stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[stys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - def num_out(self) -> int | None: - return len(self.types) - def op(self) -> Op: - return ops.UnpackTuple(self.types) +DivMod = DivModDef() def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): @@ -114,7 +83,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): def test_stable_indices(): h = Hugr(ops.DFG()) - nodes = [h.add_node(NOT_OP) for _ in range(3)] + nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 h.add_link(nodes[0].out(0), nodes[1].inp(0)) @@ -137,7 +106,7 @@ def test_stable_indices(): with pytest.raises(KeyError): _ = h[Node(46)] - new_n = h.add_node(NOT_OP) + new_n = h.add_node(Not) assert new_n == nodes[1] assert len(h) == 4 @@ -178,7 +147,7 @@ def test_multiport(): def test_add_op(): h = Dfg.endo([BOOL_T]) (a,) = h.inputs() - nt = h.add_op(NOT_OP, a) + nt = h.add_op(Not, a) h.set_outputs(nt) _validate(h.hugr) @@ -188,8 +157,8 @@ def test_tuple(): row = [BOOL_T, QB_T] h = Dfg.endo(row) a, b = h.inputs() - t = h.add(MakeTuple(row, [a, b])) - a, b = h.add(UnpackTuple(row, t)) + t = h.add(ops.MakeTuple(row)(a, b)) + a, b = h.add(ops.UnpackTuple(row)(t)) h.set_outputs(a, b) _validate(h.hugr)