Skip to content

Commit

Permalink
feat: builder ops separate from serialised ops
Browse files Browse the repository at this point in the history
no more parent=-1
  • Loading branch information
ss2165 committed May 30, 2024
1 parent 47f95c4 commit bcc9af1
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 211 deletions.
71 changes: 19 additions & 52 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
from hugr.serialization.tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG
from hugr.utils import BiMap


Expand All @@ -43,6 +43,13 @@ class Wire(Protocol):
def out_port(self) -> OutPort: ...


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING
Expand Down Expand Up @@ -100,35 +107,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:
return self.out(offset)


class Op(Protocol):
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Self: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op.model_copy()) # type: ignore

@classmethod
def from_serial(cls, serial: SerialOp) -> DummyOp:
return DummyOp(serial.root)


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand All @@ -138,10 +116,9 @@ class NodeData:
# TODO children field?

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)

return o
return SerialOp(root=o) # type: ignore


P = TypeVar("P", InPort, OutPort)
Expand Down Expand Up @@ -371,7 +348,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr.root = Node(idx)
parent = None
serial_node.root.parent = -1
hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent))
hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent))

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
Expand All @@ -395,35 +372,25 @@ def __init__(
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DummyOp(sops.DFG(parent=-1))
root_op._serial_op.signature.input = input_types
root_op._serial_op.signature.output = output_types
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, DummyOp)
assert isinstance(dop._serial_op, sops.Input)
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [
self.input_node.out(i)
for i in range(len(self._input_op()._serial_op.types))
]
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
Expand Down
98 changes: 98 additions & 0 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generic, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node


class Op(ABC):
@abstractmethod
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root


@dataclass()
class Input(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)


@dataclass()
class Output(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)


@dataclass()
class Custom(Op):
extension: tys.ExtensionId
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
args: list[tys.TypeArg] = field(default_factory=list)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
description=self.description,
args=self.args,
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)


@dataclass()
class DFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
)
Empty file added hugr-py/src/hugr/_types.py
Empty file.
135 changes: 0 additions & 135 deletions hugr-py/src/hugr/ops.py

This file was deleted.

Loading

0 comments on commit bcc9af1

Please sign in to comment.