Skip to content

Commit

Permalink
feat: builder ops separate from serialised ops
Browse files Browse the repository at this point in the history
no more parent=-1
  • Loading branch information
ss2165 committed May 30, 2024
1 parent ae162d2 commit a8e97a7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 75 deletions.
69 changes: 18 additions & 51 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
from hugr.serialization.tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG
from hugr.utils import BiMap


Expand All @@ -43,6 +43,13 @@ class Wire(Protocol):
def out_port(self) -> OutPort: ...


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None

Check warning on line 50 in hugr-py/src/hugr/_hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_hugr.py#L50

Added line #L50 was not covered by tests


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING
Expand Down Expand Up @@ -100,35 +107,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:
return self.out(offset)


class Op(Protocol):
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Self: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op.model_copy()) # type: ignore

@classmethod
def from_serial(cls, serial: SerialOp) -> DummyOp:
return DummyOp(serial.root)


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand All @@ -138,8 +116,7 @@ class NodeData:
# TODO children field?

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)

return o

Expand Down Expand Up @@ -371,7 +348,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr.root = Node(idx)
parent = None
serial_node.root.parent = -1
hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent))
hugr._nodes.append(NodeData(Op.from_serial(serial_node), parent))

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
Expand All @@ -395,35 +372,25 @@ def __init__(
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DummyOp(sops.DFG(parent=-1))
root_op._serial_op.signature.input = input_types
root_op._serial_op.signature.output = output_types
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, DummyOp)
assert isinstance(dop._serial_op, sops.Input)
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [
self.input_node.out(i)
for i in range(len(self._input_op()._serial_op.types))
]
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
Expand Down
File renamed without changes.
43 changes: 19 additions & 24 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import subprocess
import os
import pathlib
from hugr._hugr import Dfg, Hugr, DummyOp, Node, Command, Wire, Op
from hugr._hugr import Dfg, Hugr, Node, Command, Wire
from hugr._ops import Op, Custom
import hugr._ops as ops
from hugr.serialization import SerialHugr
import hugr.serialization.tys as stys
import hugr.serialization.ops as sops
import pytest
import json

Expand All @@ -22,14 +23,11 @@
)
)

NOT_OP = DummyOp(
# TODO get from YAML
sops.CustomOp(
parent=-1,
extension="logic",
op_name="Not",
signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]),
)
# TODO get from YAML
NOT_OP = Custom(
extension="logic",
op_name="Not",
signature=stys.FunctionType(input=[BOOL_T], output=[BOOL_T]),
)


Expand Down Expand Up @@ -59,14 +57,11 @@ def num_out(self) -> int | None:
return 2

def op(self) -> Op:
return DummyOp(
sops.CustomOp(
parent=-1,
extension="arithmetic.int",
op_name="idivmod_u",
signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2),
args=[ARG_5, ARG_5],
)
return Custom(
extension="arithmetic.int",
op_name="idivmod_u",
signature=stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2),
args=[ARG_5, ARG_5],
)


Expand All @@ -82,7 +77,7 @@ def num_out(self) -> int | None:
return 1

def op(self) -> Op:
return DummyOp(sops.MakeTuple(parent=-1, tys=self.types))
return ops.MakeTuple(self.types)


@dataclass
Expand All @@ -97,7 +92,7 @@ def num_out(self) -> int | None:
return len(self.types)

def op(self) -> Op:
return DummyOp(sops.UnpackTuple(parent=-1, tys=self.types))
return ops.UnpackTuple(self.types)


def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):
Expand All @@ -117,7 +112,7 @@ def _validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):


def test_stable_indices():
h = Hugr(DummyOp(sops.DFG(parent=-1)))
h = Hugr(ops.DFG())

nodes = [h.add_node(NOT_OP) for _ in range(3)]
assert len(h) == 4
Expand Down Expand Up @@ -201,8 +196,8 @@ def test_tuple():

h1 = Dfg.endo(row)
a, b = h1.inputs()
mt = h1.add_op(DummyOp(sops.MakeTuple(parent=-1, tys=row)), a, b)
a, b = h1.add_op(DummyOp(sops.UnpackTuple(parent=-1, tys=row)), mt)[0, 1]
mt = h1.add_op(ops.MakeTuple(row), a, b)
a, b = h1.add_op(ops.UnpackTuple(row), mt)[0, 1]
h1.set_outputs(a, b)

assert h.hugr.to_serial() == h1.hugr.to_serial()
Expand All @@ -224,7 +219,7 @@ def test_insert():

assert len(h1.hugr) == 4

new_h = Hugr(DummyOp(sops.DFG(parent=-1)))
new_h = Hugr(ops.DFG())
mapping = h1.hugr.insert_hugr(new_h, h1.hugr.root)
assert mapping == {new_h.root: Node(4)}

Expand Down

0 comments on commit a8e97a7

Please sign in to comment.