Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): builder ops separate from serialised ops #1140

Merged
merged 5 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 13 additions & 53 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import BaseOp, OpType as SerialOp
import hugr.serialization.ops as sops
from hugr.serialization.tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG, Command
from hugr.utils import BiMap


Expand Down Expand Up @@ -101,35 +101,6 @@ def port(self, offset: int, direction: Direction) -> InPort | OutPort:
return self.out(offset)


class Op(Protocol):
def to_serial(self, node: Node, hugr: Hugr) -> SerialOp: ...

@classmethod
def from_serial(cls, serial: SerialOp) -> Self: ...


T = TypeVar("T", bound=BaseOp)


@dataclass()
class DummyOp(Op, Generic[T]):
_serial_op: T

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=self._serial_op.model_copy()) # type: ignore

@classmethod
def from_serial(cls, serial: SerialOp) -> DummyOp:
return DummyOp(serial.root)


class Command(Protocol):
def op(self) -> Op: ...
def incoming(self) -> Iterable[Wire]: ...
def num_out(self) -> int | None:
return None


@dataclass()
class NodeData:
op: Op
Expand All @@ -139,10 +110,9 @@ class NodeData:
# TODO children field?

def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
o = self.op.to_serial(node, hugr)
o.root.parent = self.parent.idx if self.parent else node.idx
o = self.op.to_serial(node, self.parent if self.parent else node, hugr)

return o
return SerialOp(root=o) # type: ignore[arg-type]


P = TypeVar("P", InPort, OutPort)
Expand Down Expand Up @@ -372,7 +342,7 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
hugr.root = Node(idx)
parent = None
serial_node.root.parent = -1
hugr._nodes.append(NodeData(DummyOp.from_serial(serial_node), parent))
hugr._nodes.append(NodeData(serial_node.root.deserialize(), parent))

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
if src_offset is None or dst_offset is None:
Expand All @@ -396,43 +366,33 @@ def __init__(
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DummyOp(sops.DFG(parent=-1))
root_op._serial_op.signature.input = input_types
root_op._serial_op.signature.output = output_types
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
DummyOp(sops.Input(parent=0, types=input_types)),
self.root,
len(input_types),
)
self.output_node = self.hugr.add_node(
DummyOp(sops.Output(parent=0, types=output_types)), self.root
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> DummyOp[sops.Input]:
def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, DummyOp)
assert isinstance(dop._serial_op, sops.Input)
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [
self.input_node.out(i)
for i in range(len(self._input_op()._serial_op.types))
]
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op(), *com.incoming(), num_outs=com.num_out())
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
Expand Down
123 changes: 123 additions & 0 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire

Check warning on line 10 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L10

Added line #L10 was not covered by tests


class Op(Protocol):
@property
def num_out(self) -> int | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unset for many op definitions (Input, SerWrap, Custom, DFG)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can't be known in general - but can be for many of those! (which I've added). Am unsure about Custom but I've added it for now. It should be left as the default None for SerWrap.

return None

Check warning on line 16 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L16

Added line #L16 was not covered by tests

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...

def __call__(self, *args) -> Command:
return Command(self, list(args))


@dataclass(frozen=True)
class Command:
op: Op
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root

Check warning on line 41 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L39-L41

Added lines #L39 - L41 were not covered by tests


@dataclass()
class Input(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)

def __call__(self) -> Command:
return super().__call__()

Check warning on line 52 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L52

Added line #L52 was not covered by tests


@dataclass()
class Output(Op):
types: list[tys.Type]

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)


@dataclass()
class Custom(Op):
op_name: str
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)
description: str = ""
extension: tys.ExtensionId = ""
args: list[tys.TypeArg] = field(default_factory=list)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
return sops.CustomOp(
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
description=self.description,
args=self.args,
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)


@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
)

def __call__(self, tuple_: Wire) -> Command:
return super().__call__(tuple_)


@dataclass()
class DFG(Op):
signature: tys.FunctionType = field(default_factory=tys.FunctionType.empty)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
)
31 changes: 31 additions & 0 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import inspect
import sys
from abc import ABC
Expand Down Expand Up @@ -41,6 +42,10 @@
"""Name of the op for visualisation"""
return self.__class__.__name__

def deserialize(self) -> ops.Op:
"""Deserializes the model into the corresponding Op."""
return ops.SerWrap(self)

Check warning on line 47 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L47

Added line #L47 was not covered by tests


# ----------------------------------------------------------
# --------------- Module level operations ------------------
Expand Down Expand Up @@ -209,6 +214,9 @@
assert len(in_types) == 0
self.types = list(out_types)

def deserialize(self) -> ops.Input:
return ops.Input(types=self.types)


class Output(DataflowOp):
"""An output node. The inputs are the outputs of the function."""
Expand All @@ -220,6 +228,9 @@
assert len(out_types) == 0
self.types = list(in_types)

def deserialize(self) -> ops.Output:
return ops.Output(types=self.types)


class Call(DataflowOp):
"""
Expand Down Expand Up @@ -292,6 +303,9 @@
input=list(inputs), output=list(outputs), extension_reqs=ExtensionSet([])
)

def deserialize(self) -> ops.DFG:
return ops.DFG(self.signature)


# ------------------------------------------------
# --------------- ControlFlowOp ------------------
Expand Down Expand Up @@ -388,6 +402,14 @@
def display_name(self) -> str:
return self.op_name

def deserialize(self) -> ops.Custom:
return ops.Custom(
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
args=self.args,
)

model_config = ConfigDict(
# Needed to avoid random '\n's in the pydantic description
json_schema_extra={
Expand Down Expand Up @@ -424,6 +446,9 @@
in_types = []
self.tys = list(in_types)

def deserialize(self) -> ops.MakeTuple:
return ops.MakeTuple(self.tys)


class UnpackTuple(DataflowOp):
"""An operation that packs all its inputs into a tuple."""
Expand All @@ -434,6 +459,9 @@
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
self.tys = list(out_types)

def deserialize(self) -> ops.UnpackTuple:
return ops.UnpackTuple(self.tys)


class Tag(DataflowOp):
"""An operation that creates a tagged sum value from one of its variants."""
Expand Down Expand Up @@ -529,3 +557,6 @@
)

tys_model_rebuild(dict(classes))

#
import hugr._ops as ops # noqa: E402 # needed to avoid circular imports
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh -.-

This re-exports hugr._ops as hugr.serialization.ops.ops, so it ends up as part of the public API.
You should leave it as

from hugr import _ops

for now.

Loading
Loading