Skip to content

Commit

Permalink
feat(hugr-py): add builders for Conditional and TailLoop (#1210)
Browse files Browse the repository at this point in the history
Closes #1204 

Including a short hand for if/else style conditionals

Includes some other simplifications and refactors, reccommend going
commit by commit
  • Loading branch information
ss2165 authored Jun 21, 2024
1 parent f7ea178 commit 43569a4
Show file tree
Hide file tree
Showing 9 changed files with 514 additions and 69 deletions.
16 changes: 8 additions & 8 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
from ._tys import TypeRow, Type
import hugr._val as val


Expand All @@ -16,7 +16,7 @@ def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
u = self.load(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
Expand Down Expand Up @@ -47,7 +47,7 @@ class Cfg(ParentBuilder[ops.CFG]):
exit: Node

def __init__(self, input_types: TypeRow) -> None:
root_op = ops.CFG(FunctionType(input=input_types, output=[]))
root_op = ops.CFG(inputs=input_types)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, input_types)

Expand All @@ -68,7 +68,7 @@ def new_nested(
) -> Cfg:
new = cls.__new__(cls)
root = hugr.add_node(
ops.CFG(FunctionType(input=input_types, output=[])),
ops.CFG(inputs=input_types),
parent or hugr.root,
)
new._init_impl(hugr, root, input_types)
Expand Down Expand Up @@ -97,6 +97,8 @@ def add_block(self, input_types: TypeRow) -> Block:
)
return new_block

# TODO insert_block

def add_successor(self, pred: Wire) -> Block:
b = self.add_block(self._nth_outputs(pred))

Expand Down Expand Up @@ -125,6 +127,4 @@ def branch_exit(self, src: Wire) -> None:
raise MismatchedExit(src.node.idx)
else:
self._exit_op._cfg_outputs = out_types
self.parent_op.signature = replace(
self.parent_op.signature, output=out_types
)
self.parent_op._outputs = out_types
110 changes: 110 additions & 0 deletions hugr-py/src/hugr/_cond_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops

from ._dfg import _DfBase
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import Sum, TypeRow


class Case(_DfBase[ops.Case]):
_parent_cond: Conditional | None = None

def set_outputs(self, *outputs: Wire) -> None:
super().set_outputs(*outputs)
if self._parent_cond is not None:
self._parent_cond._update_outputs(self._wire_types(outputs))


class ConditionalError(Exception):
pass


@dataclass
class _IfElse(Case):
def __init__(self, case: Case) -> None:
self.hugr = case.hugr
self.parent_node = case.parent_node
self.input_node = case.input_node
self.output_node = case.output_node
self._parent_cond = case._parent_cond

def _parent_conditional(self) -> Conditional:
if self._parent_cond is None:
raise ConditionalError("If must have a parent conditional.")
return self._parent_cond


class If(_IfElse):
def add_else(self) -> Else:
return Else(self._parent_conditional().add_case(0))


class Else(_IfElse):
def finish(self) -> Node:
return self._parent_conditional().parent_node


@dataclass
class Conditional(ParentBuilder[ops.Conditional]):
cases: dict[int, Node | None]

def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
root_op = ops.Conditional(sum_ty, other_inputs)
hugr = Hugr(root_op)
self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows))

def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
self.hugr = hugr
self.parent_node = root
self.cases = {i: None for i in range(n_cases)}

@classmethod
def new_nested(
cls,
sum_ty: Sum,
other_inputs: TypeRow,
hugr: Hugr,
parent: ToNode | None = None,
) -> Conditional:
new = cls.__new__(cls)
root = hugr.add_node(
ops.Conditional(sum_ty, other_inputs),
parent or hugr.root,
)
new._init_impl(hugr, root, len(sum_ty.variant_rows))
return new

def _update_outputs(self, outputs: TypeRow) -> None:
if self.parent_op._outputs is None:
self.parent_op._outputs = outputs
else:
if outputs != self.parent_op._outputs:
raise ConditionalError("Mismatched case outputs.")

def add_case(self, case_id: int) -> Case:
if case_id not in self.cases:
raise ConditionalError(f"Case {case_id} out of possible range.")
input_types = self.parent_op.nth_inputs(case_id)
new_case = Case.new_nested(
ops.Case(input_types),
self.hugr,
self.parent_node,
)
new_case._parent_cond = self
self.cases[case_id] = new_case.parent_node
return new_case

# TODO insert_case


@dataclass
class TailLoop(_DfBase[ops.TailLoop]):
def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None:
root_op = ops.TailLoop(just_inputs, rest)
super().__init__(root_op)

def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
self.set_outputs(sum_wire, *rest)
74 changes: 55 additions & 19 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from typing import (
TYPE_CHECKING,
Iterable,
Sequence,
TypeVar,
)

from typing_extensions import Self

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow
from hugr._tys import Type, TypeRow, get_first_sum

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
from ._cond_loop import Conditional, If, TailLoop


DP = TypeVar("DP", bound=ops.DfParentOp)
Expand Down Expand Up @@ -72,39 +74,73 @@ def add_op(self, op: ops.DataflowOp, /, *args: Wire) -> Node:
def add(self, com: ops.Command) -> Node:
return self.add_op(com.op, *com.incoming)

def _insert_nested_impl(self, builder: ParentBuilder, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(builder.hugr, self.parent_node)
self._wire_up(mapping[builder.parent_node], args)
return mapping[builder.parent_node]

def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(dfg.hugr, self.parent_node)
self._wire_up(mapping[dfg.parent_node], args)
return mapping[dfg.parent_node]
return self._insert_nested_impl(dfg, *args)

def add_nested(
self,
*args: Wire,
) -> Dfg:
from ._dfg import Dfg

input_types = [self._get_dataflow_type(w) for w in args]

parent_op = ops.DFG(list(input_types))
parent_op = ops.DFG(self._wire_types(args))
dfg = Dfg.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(dfg.parent_node, args)
return dfg

def _wire_types(self, args: Iterable[Wire]) -> TypeRow:
return [self._get_dataflow_type(w) for w in args]

def add_cfg(
self,
input_types: TypeRow,
*args: Wire,
) -> Cfg:
from ._cfg import Cfg

cfg = Cfg.new_nested(input_types, self.hugr, self.parent_node)
cfg = Cfg.new_nested(self._wire_types(args), self.hugr, self.parent_node)
self._wire_up(cfg.parent_node, args)
return cfg

def insert_cfg(self, cfg: Cfg, *args: Wire) -> Node:
mapping = self.hugr.insert_hugr(cfg.hugr, self.parent_node)
self._wire_up(mapping[cfg.parent_node], args)
return mapping[cfg.parent_node]
return self._insert_nested_impl(cfg, *args)

def add_conditional(self, cond: Wire, *args: Wire) -> Conditional:
from ._cond_loop import Conditional

args = (cond, *args)
(sum_, other_inputs) = get_first_sum(self._wire_types(args))
cond = Conditional.new_nested(sum_, other_inputs, self.hugr, self.parent_node)
self._wire_up(cond.parent_node, args)
return cond

def insert_conditional(self, cond: Conditional, *args: Wire) -> Node:
return self._insert_nested_impl(cond, *args)

def add_if(self, cond: Wire, *args: Wire) -> If:
from ._cond_loop import If

conditional = self.add_conditional(cond, *args)
return If(conditional.add_case(1))

def add_tail_loop(
self, just_inputs: Sequence[Wire], rest: Sequence[Wire]
) -> TailLoop:
from ._cond_loop import TailLoop

just_input_types = self._wire_types(just_inputs)
rest_types = self._wire_types(rest)
parent_op = ops.TailLoop(just_input_types, rest_types)
tl = TailLoop.new_nested(parent_op, self.hugr, self.parent_node)
self._wire_up(tl.parent_node, (*just_inputs, *rest))
return tl

def insert_tail_loop(self, tl: TailLoop, *args: Wire) -> Node:
return self._insert_nested_impl(tl, *args)

def set_outputs(self, *args: Wire) -> None:
self._wire_up(self.output_node, args)
Expand All @@ -117,22 +153,22 @@ def add_state_order(self, src: Node, dst: Node) -> None:
def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
def load(self, const: ToNode | val.Value) -> Node:
if isinstance(const, val.Value):
const = self.add_const(const)
const_op = self.hugr._get_typed_op(const, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))
self.hugr.add_link(const.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
op.set_in_types(tys)
return tys

def _get_dataflow_type(self, wire: Wire) -> Type:
port = wire.out_port()
Expand Down
Loading

0 comments on commit 43569a4

Please sign in to comment.