Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): only require input type annotations when building #1199

Merged
merged 17 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 55 additions & 59 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Sequence
from dataclasses import dataclass, replace

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, Sum, TypeRow
from ._tys import FunctionType, TypeRow, Type


class Block(_DfBase[ops.DataflowBlock]):
Expand All @@ -19,99 +18,96 @@
# TODO requires constants
raise NotImplementedError

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
src = p.out_port()
cfg_node = self.hugr[self.root].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
self._wire_up_port(node, i, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(i))
def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
cfg_node = self.hugr[self.parent_node].parent
assert cfg_node is not None
src_parent = self.hugr[src.node].parent
try:
super()._wire_up_port(node, offset, p)
except NoSiblingAncestor:
# note this just checks if there is a common CFG ancestor
# it does not check for valid dominance between basic blocks
# that is deferred to full HUGR validation.
while cfg_node != src_parent:
if src_parent is None or src_parent == self.hugr.root:
raise NotInSameCfg(src.node.idx, node.idx)

Check warning on line 34 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L34

Added line #L34 was not covered by tests
src_parent = self.hugr[src_parent].parent

self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


@dataclass
class Cfg(ParentBuilder):
class Cfg(ParentBuilder[ops.CFG]):
hugr: Hugr
root: Node
parent_node: Node
_entry_block: Block
exit: Node

def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=output_types))
def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types, output_types)
self._init_impl(hugr, hugr.root, input_types)

def _init_impl(
self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow, output_types: TypeRow
) -> None:
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
self.hugr = hugr
self.root = root
self.parent_node = root
# to ensure entry is first child, add a dummy entry at the start
self._entry_block = Block.new_nested(
ops.DataflowBlock(input_types, []), hugr, root
)

self.exit = self.hugr.add_node(ops.ExitBlock(output_types), self.root)
self.exit = self.hugr.add_node(ops.ExitBlock([]), self.parent_node)

@classmethod
def new_nested(
cls,
input_types: TypeRow,
output_types: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=output_types)),
ops.CFG(FunctionType(input=input_types, output=[])),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types, output_types)
new._init_impl(hugr, root, input_types)
return new

@property
def entry(self) -> Node:
return self._entry_block.root
return self._entry_block.parent_node

Check warning on line 80 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L80

Added line #L80 was not covered by tests

def _entry_op(self) -> ops.DataflowBlock:
dop = self.hugr[self.entry].op
assert isinstance(dop, ops.DataflowBlock)
return dop

def add_entry(self, sum_rows: Sequence[TypeRow], other_outputs: TypeRow) -> Block:
# update entry block types
self._entry_op().sum_rows = list(sum_rows)
self._entry_op().other_outputs = other_outputs
self._entry_block._output_op().types = [Sum(list(sum_rows)), *other_outputs]
return self._entry_block
return self.hugr._get_typed_op(self.entry, ops.DataflowBlock)

Check warning on line 83 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L83

Added line #L83 was not covered by tests

def _exit_op(self) -> ops.ExitBlock:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
return self.hugr._get_typed_op(self.exit, ops.ExitBlock)

def simple_entry(self, n_branches: int, other_outputs: TypeRow) -> Block:
return self.add_entry([[]] * n_branches, other_outputs)
def add_entry(self) -> Block:
return self._entry_block

def add_block(
self, input_types: TypeRow, sum_rows: Sequence[TypeRow], other_outputs: TypeRow
) -> Block:
def add_block(self, input_types: TypeRow) -> Block:
new_block = Block.new_nested(
ops.DataflowBlock(input_types, list(sum_rows), other_outputs),
ops.DataflowBlock(input_types, [], []),
self.hugr,
self.root,
self.parent_node,
)
return new_block

def simple_block(
self, input_types: TypeRow, n_branches: int, other_outputs: TypeRow
) -> Block:
return self.add_block(input_types, [[]] * n_branches, other_outputs)

def branch(self, src: Wire, dst: ToNode) -> None:
self.hugr.add_link(src.out_port(), dst.inp(0))
src = src.out_port()
self.hugr.add_link(src, dst.inp(0))

if dst == self.exit:
src_block = self.hugr._get_typed_op(src.node, ops.DataflowBlock)
out_types = [*src_block.sum_rows[src.offset], *src_block.other_outputs]
if self._exit_op().cfg_outputs:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
if self._exit_op().cfg_outputs != out_types:
raise MismatchedExit(src.node.idx)

Check warning on line 108 in hugr-py/src/hugr/_cfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_cfg.py#L108

Added line #L108 was not covered by tests
else:
self._exit_op().cfg_outputs = out_types
self.parent_op().signature = replace(
self.parent_op().signature, output=out_types
)
102 changes: 55 additions & 47 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, TypeVar, cast
from typing_extensions import Self
import hugr._ops as ops
from hugr._tys import FunctionType, TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, Wire, ToNode
from ._hugr import ToNode
from hugr._tys import Type

if TYPE_CHECKING:
from ._cfg import Cfg
Expand All @@ -17,74 +23,70 @@


@dataclass()
class _DfBase(ParentBuilder, Generic[DP]):
class _DfBase(ParentBuilder[DP]):
hugr: Hugr
root: Node
parent_node: Node
input_node: Node
output_node: Node

def __init__(self, root_op: DP) -> None:
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
self.hugr = Hugr(root_op)
self.root = self.hugr.root
self.parent_node = self.hugr.root
self._init_io_nodes(root_op)

def _init_io_nodes(self, root_op: DP):
input_types = root_op.input_types()
output_types = root_op.output_types()
inner_sig = root_op.inner_signature()

self.input_node = self.hugr.add_node(
ops.Input(input_types), self.root, len(input_types)
ops.Input(inner_sig.input), self.parent_node, len(inner_sig.input)
)
self.output_node = self.hugr.add_node(
ops.Output(inner_sig.output), self.parent_node
)
self.output_node = self.hugr.add_node(ops.Output(output_types), self.root)

@classmethod
def new_nested(cls, root_op: DP, hugr: Hugr, parent: ToNode | None = None) -> Self:
new = cls.__new__(cls)

new.hugr = hugr
new.root = hugr.add_node(root_op, parent or hugr.root)
new.parent_node = hugr.add_node(root_op, parent or hugr.root)
new._init_io_nodes(root_op)
return new

def _input_op(self) -> ops.Input:
dop = self.hugr[self.input_node].op
assert isinstance(dop, ops.Input)
return dop
return self.hugr._get_typed_op(self.input_node, ops.Input)

def _output_op(self) -> ops.Output:
dop = self.hugr[self.output_node].op
assert isinstance(dop, ops.Output)
return dop

def root_op(self) -> DP:
return cast(DP, self.hugr[self.root].op)
return self.hugr._get_typed_op(self.output_node, ops.Output)

def inputs(self) -> list[OutPort]:
return [self.input_node.out(i) for i in range(len(self._input_op().types))]

def add_op(self, op: ops.Op, /, *args: Wire, num_outs: int | None = None) -> Node:
new_n = self.hugr.add_node(op, self.root, num_outs=num_outs)
def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
new_n = self.hugr.add_node(op, self.parent_node)
self._wire_up(new_n, args)
return new_n

return replace(new_n, _num_out_ports=op.num_out)

def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming, num_outs=com.op.num_out)
return self.add_op(com.op, *com.incoming)

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.root)
self._wire_up(mapping[dfg.root], args)
return mapping[dfg.root]
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]

def add_nested(
self,
input_types: TypeRow,
output_types: TypeRow,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
dfg = Dfg.new_nested(root_op, self.hugr, self.root)
self._wire_up(dfg.root, args)
input_types = [self._get_dataflow_type(w) for w in args]

root_op = ops.DFG(FunctionType(input=list(input_types), output=[]))
dfg = Dfg.new_nested(root_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def add_cfg(
Expand All @@ -95,45 +97,51 @@
) -> Cfg:
from ._cfg import Cfg

ss2165 marked this conversation as resolved.
Show resolved Hide resolved
cfg = Cfg.new_nested(input_types, output_types, self.hugr, self.root)
self._wire_up(cfg.root, args)
cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.root)
self._wire_up(mapping[cfg.root], args)
return mapping[cfg.root]
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]

Check warning on line 107 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L105-L107

Added lines #L105 - L107 were not covered by tests

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
self.parent_op()._set_out_types(self._output_op()._types())

def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
for i, p in enumerate(ports):
self._wire_up_port(node, i, p)

def _wire_up_port(self, node: Node, offset: int, p: Wire):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
ty = self.hugr.port_type(port)
if ty is None:
raise ValueError(f"Port {port} is not a dataflow port.")

Check warning on line 126 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L126

Added line #L126 was not covered by tests
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
raise NoSiblingAncestor(src.node.idx, node.idx)
if node_ancestor != node:
self.add_state_order(src.node, node_ancestor)
self.hugr.add_link(src, node.inp(offset))
return self._get_dataflow_type(src)


class Dfg(_DfBase[ops.DFG]):
def __init__(self, input_types: TypeRow, output_types: TypeRow) -> None:
root_op = ops.DFG(FunctionType(input=input_types, output=output_types))
def __init__(self, *input_types: Type) -> None:
root_op = ops.DFG(FunctionType(input=list(input_types), output=[]))
super().__init__(root_op)

@classmethod
def endo(cls, types: TypeRow) -> Dfg:
return cls(types, types)


def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:
src_parent = h[src].parent
Expand Down
16 changes: 16 additions & 0 deletions hugr-py/src/hugr/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,21 @@
return f"Source {self.src} is not in the same CFG as target {self.tgt}, so cannot wire up."


@dataclass
class MismatchedExit(Exception):
src: int

@property
def msg(self):
return (

Check warning on line 30 in hugr-py/src/hugr/_exceptions.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_exceptions.py#L30

Added line #L30 was not covered by tests
f"Exit branch from node {self.src} does not match existing exit block type."
)


class ParentBeforeChild(Exception):
msg: str = "Parent node must be added before child node."


@dataclass
class IncompleteOp(Exception):
msg: str = "Operation is incomplete, may require set_in_types to be called."
Loading
Loading