Skip to content

Commit

Permalink
refactor(hugr-py): move dfg to own file
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 5, 2024
1 parent 342eda3 commit 868a418
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 74 deletions.
78 changes: 78 additions & 0 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, Iterable
from ._hugr import Hugr, Node, Wire, OutPort

from ._ops import Op, Command, Input, Output, DFG
from hugr.serialization.tys import FunctionType, Type


@dataclass()
class Dfg:
hugr: Hugr
root: Node
input_node: Node
output_node: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]

def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
*args: Wire,
) -> Dfg:
dfg = self.hugr.add_dfg(input_types, output_types)
self._wire_up(dfg.root, args)
return dfg

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
# breaks if further edges are added
self.hugr.add_link(
src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst))
)

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
self.hugr.add_link(src, node.inp(i))
81 changes: 8 additions & 73 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,20 @@
cast,
overload,
ClassVar,
TYPE_CHECKING,
)

from typing_extensions import Self

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type, FunctionType
from hugr._ops import Op, Input, Output, DFG, Command
from hugr.serialization.tys import Type
from hugr._ops import Op
from hugr.utils import BiMap

if TYPE_CHECKING:
from ._dfg import Dfg

Check warning on line 27 in hugr-py/src/hugr/_hugr.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_hugr.py#L27

Added line #L27 was not covered by tests


class Direction(Enum):
INCOMING = 0
Expand Down Expand Up @@ -304,6 +308,8 @@ def insert_hugr(self, hugr: Hugr, parent: Node | None = None) -> dict[Node, Node
return mapping

def add_dfg(self, input_types: Sequence[Type], output_types: Sequence[Type]) -> Dfg:
from ._dfg import Dfg

dfg = Dfg(input_types, output_types)
mapping = self.insert_hugr(dfg.hugr, self.root)
dfg.hugr = self
Expand Down Expand Up @@ -352,74 +358,3 @@ def from_serial(cls, serial: SerialHugr) -> Hugr:
)

return hugr


@dataclass()
class Dfg:
hugr: Hugr
root: Node
input_node: Node
output_node: Node

def __init__(
self, input_types: Sequence[Type], output_types: Sequence[Type]
) -> None:
input_types = list(input_types)
output_types = list(output_types)
root_op = DFG(FunctionType(input=input_types, output=output_types))
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.input_node = self.hugr.add_node(
Input(input_types), self.root, len(input_types)
)
self.output_node = self.hugr.add_node(Output(output_types), self.root)

@classmethod
def endo(cls, types: Sequence[Type]) -> Dfg:
return Dfg(types, types)

def _input_op(self) -> Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, Input)
return dop

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
self._wire_up(new_n, args)
return new_n

def add(self, com: Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]

def add_nested(
self,
input_types: Sequence[Type],
output_types: Sequence[Type],
*args: Wire,
) -> Dfg:
dfg = self.hugr.add_dfg(input_types, output_types)
self._wire_up(dfg.root, args)
return dfg

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
# breaks if further edges are added
self.hugr.add_link(
src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst))
)

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
self.hugr.add_link(src, node.inp(i))
3 changes: 2 additions & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import subprocess
import os
import pathlib
from hugr._hugr import Dfg, Hugr, Node, Wire
from hugr._hugr import Hugr, Node, Wire
from hugr._dfg import Dfg
from hugr._ops import Custom, Command
import hugr._ops as ops
from hugr.serialization import SerialHugr
Expand Down

0 comments on commit 868a418

Please sign in to comment.