Skip to content

Commit

Permalink
feat: PartialOp protocol for cleaner op mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 18, 2024
1 parent e8b3a17 commit cf135eb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
6 changes: 3 additions & 3 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op()._set_out_types(self._output_op().types)
self.parent_op()._set_out_types(self._output_op()._types())

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.DataflowOp):
op._set_in_types(tys)
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
Expand Down
5 changes: 5 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ def msg(self):

class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."


@dataclass
class IncompleteOp(Exception):
msg: str = "Operation is incomplete, may require set_in_types to be called."
51 changes: 34 additions & 17 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hugr.serialization.ops as sops
from hugr.utils import ser_it
import hugr._tys as tys
from ._exceptions import IncompleteOp

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire, _Port

Check warning on line 12 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L12

Added line #L12 was not covered by tests
Expand Down Expand Up @@ -39,13 +40,15 @@ def port_type(self, port: _Port) -> tys.Type:
return sig.input[port.offset]

Check warning on line 40 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L40

Added line #L40 was not covered by tests
return sig.output[port.offset]

def _set_in_types(self, types: tys.TypeRow) -> None:
return

def __call__(self, *args) -> Command:
return Command(self, list(args))


@runtime_checkable
class PartialOp(DataflowOp, Protocol):
def set_in_types(self, types: tys.TypeRow) -> None: ...


@dataclass(frozen=True)
class Command:
op: DataflowOp
Expand Down Expand Up @@ -88,16 +91,21 @@ def __call__(self) -> Command:


@dataclass()
class Output(DataflowOp):
types: tys.TypeRow = field(default_factory=list)
class Output(PartialOp):
types: tys.TypeRow | None = None

def _types(self) -> tys.TypeRow:
if self.types is None:
raise IncompleteOp()

Check warning on line 99 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L99

Added line #L99 was not covered by tests
return self.types

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=ser_it(self.types))
return sops.Output(parent=parent.idx, types=ser_it(self._types()))

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[])
return tys.FunctionType(input=self._types(), output=[])

Check warning on line 106 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L106

Added line #L106 was not covered by tests

def _set_in_types(self, types: tys.TypeRow) -> None:
def set_in_types(self, types: tys.TypeRow) -> None:
self.types = types


Expand Down Expand Up @@ -128,33 +136,43 @@ def outer_signature(self) -> tys.FunctionType:


@dataclass()
class MakeTupleDef(DataflowOp):
types: tys.TypeRow = field(default_factory=list)
class MakeTupleDef(PartialOp):
types: tys.TypeRow | None = None
num_out: int | None = 1

def _types(self) -> tys.TypeRow:
if self.types is None:
raise IncompleteOp()

Check warning on line 145 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L145

Added line #L145 was not covered by tests
return self.types

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=ser_it(self.types),
tys=ser_it(self._types()),
)

def __call__(self, *elements: Wire) -> Command:
return super().__call__(*elements)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=self.types, output=[tys.Tuple(*self.types)])
return tys.FunctionType(input=self._types(), output=[tys.Tuple(*self._types())])

def _set_in_types(self, types: tys.TypeRow) -> None:
def set_in_types(self, types: tys.TypeRow) -> None:
self.types = types


MakeTuple = MakeTupleDef()


@dataclass()
class UnpackTupleDef(DataflowOp):
class UnpackTupleDef(PartialOp):
types: tys.TypeRow = field(default_factory=list)

def _types(self) -> tys.TypeRow:
if self.types is None:
raise IncompleteOp()
return self.types

Check warning on line 174 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L172-L174

Added lines #L172 - L174 were not covered by tests

@property
def num_out(self) -> int | None:
return len(self.types)
Expand All @@ -171,7 +189,7 @@ def __call__(self, tuple_: Wire) -> Command:
def outer_signature(self) -> tys.FunctionType:
return MakeTupleDef(self.types).outer_signature().flip()

def _set_in_types(self, types: tys.TypeRow) -> None:
def set_in_types(self, types: tys.TypeRow) -> None:
(t,) = types
assert isinstance(t, tys.Sum), f"Expected unary Sum, got {t}"
(row,) = t.variant_rows
Expand Down Expand Up @@ -203,8 +221,7 @@ def outer_signature(self) -> tys.FunctionType:
class DfParentOp(Op, Protocol):
def inner_signature(self) -> tys.FunctionType: ...

def _set_out_types(self, types: tys.TypeRow) -> None:
return
def _set_out_types(self, types: tys.TypeRow) -> None: ...


@dataclass()
Expand Down

0 comments on commit cf135eb

Please sign in to comment.