Skip to content

Commit

Permalink
feat!: Use the hugr builder (#366)
Browse files Browse the repository at this point in the history
This is an in-progress transplant of guppy from using an ad-hoc `Hugr`
definition into the builder from the `hugr` library.

The main mismatch between the old and new builders is that guppy used to
lazily add typed ports to the graph nodes as it needed. Since the new
builder requires operations to define their signature, we end up having
to edit the code everywhere from `guppylang.compiler` to the
`guppylang.prelude` definitions.

My first goal is to get this working, and then I will try and split some
changes from this PR (although most of it will have to be merged
monolithically).
 
Current test status:
- 390 $\color{green}\text{passed}$
- 14 $\color{yellow}\text{skipped}$
- 0 $\color{red}\text{failed}$

These issues currently breake some tests. I'll fix them before merging
the PR
- CQCL/hugr#1319
  ~~Required to store the module names in the hugr metadata~~
- CQCL/hugr#1424
  ~~Required for pytket and llvm integration~~
- ~~Release hugr and tket2 with the latest changes~~

Closes #257. Closes #85

BREAKING CHANGE: Removed `guppylang.hugr_builder.hugr.Hugr`, compiling a
module returns a `hugr.Hugr` instead.

---------

Co-authored-by: Mark Koch <[email protected]>
  • Loading branch information
aborgna-q and mark-koch authored Aug 21, 2024
1 parent dd702ce commit 536abf9
Show file tree
Hide file tree
Showing 42 changed files with 1,501 additions and 2,152 deletions.
8 changes: 4 additions & 4 deletions guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ast
from collections.abc import Sequence

from guppylang.ast_util import AstVisitor, with_loc
from guppylang.ast_util import AstVisitor, with_loc, with_type
from guppylang.cfg.bb import BB, BBStatement
from guppylang.checker.core import Context, FieldAccess, Variable
from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer
Expand Down Expand Up @@ -53,7 +53,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr:
case ast.Name(id=x):
var = Variable(x, ty, lhs)
self.ctx.locals[x] = var
return with_loc(lhs, PlaceNode(place=var))
return with_loc(lhs, with_type(ty, PlaceNode(place=var)))

# The LHS could also be a field `expr.field`
case ast.Attribute(value=value, attr=attr):
Expand Down Expand Up @@ -93,7 +93,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr:
"Mutation of classical fields is not supported yet", lhs
)
place = FieldAccess(value.place, struct_ty.field_dict[attr], lhs)
return with_loc(lhs, PlaceNode(place=place))
return with_loc(lhs, with_type(ty, PlaceNode(place=place)))

# The only other thing we support right now are tuples
case ast.Tuple(elts=elts) as lhs:
Expand All @@ -109,7 +109,7 @@ def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> ast.expr:
self._check_assign(pat, el_ty, node)
for pat, el_ty in zip(elts, tys, strict=True)
]
return lhs
return with_type(ty, lhs)

# TODO: Python also supports assignments like `[a, b] = [1, 2]` or
# `a, *b = ...`. The former would require some runtime checks but
Expand Down
97 changes: 51 additions & 46 deletions guppylang/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import functools
from typing import TYPE_CHECKING
from collections.abc import Sequence

from hugr import Wire, ops
from hugr import cfg as hc
from hugr import tys as ht
from hugr.dfg import DP, _DfBase
from hugr.node_port import ToNode

from guppylang.checker.cfg_checker import CheckedBB, CheckedCFG, Row, Signature
from guppylang.checker.core import Place, Variable
Expand All @@ -11,62 +17,68 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.compiler.stmt_compiler import StmtCompiler
from guppylang.hugr_builder.hugr import CFNode, Hugr, Node, OutPortV
from guppylang.tys.builtin import is_bool_type
from guppylang.tys.ty import SumType, row_to_type, type_to_row

if TYPE_CHECKING:
from collections.abc import Sequence


def compile_cfg(
cfg: CheckedCFG[Place], graph: Hugr, parent: Node, globals: CompiledGlobals
) -> None:
cfg: CheckedCFG[Place],
container: _DfBase[DP],
inputs: Sequence[Wire],
globals: CompiledGlobals,
) -> hc.Cfg:
"""Compiles a CFG to Hugr."""
insert_return_vars(cfg)

blocks: dict[CheckedBB[Place], CFNode] = {}
builder = container.add_cfg(*inputs)

blocks: dict[CheckedBB[Place], ToNode] = {}
for bb in cfg.bbs:
blocks[bb] = compile_bb(bb, graph, parent, bb == cfg.entry_bb, globals)
blocks[bb] = compile_bb(bb, builder, bb == cfg.entry_bb, globals)
for bb in cfg.bbs:
for succ in bb.successors:
graph.add_edge(blocks[bb].add_out_port(), blocks[succ].in_port(None))
for i, succ in enumerate(bb.successors):
builder.branch(blocks[bb][i], blocks[succ])

return builder


def compile_bb(
bb: CheckedBB[Place],
graph: Hugr,
parent: Node,
builder: hc.Cfg,
is_entry: bool,
globals: CompiledGlobals,
) -> CFNode:
"""Compiles a single basic block to Hugr."""
inputs = bb.sig.input_row if is_entry else sort_vars(bb.sig.input_row)
) -> ToNode:
"""Compiles a single basic block to Hugr, and returns the resulting block.
If the basic block is the output block, returns `None`.
"""
# The exit BB is completely empty
if len(bb.successors) == 0:
assert len(bb.statements) == 0
return graph.add_exit([v.ty for v in inputs], parent)
return builder.exit

# Otherwise, we use a regular `Block` node
block = graph.add_block(parent)
block: hc.Block
inputs: Sequence[Place]
if is_entry:
inputs = bb.sig.input_row
block = builder.add_entry()
else:
inputs = sort_vars(bb.sig.input_row)
block = builder.add_block(*(v.ty.to_hugr() for v in inputs))

# Add input node and compile the statements
inp = graph.add_input(output_tys=[v.ty for v in inputs], parent=block)
dfg = DFContainer(graph, block)
for i, v in enumerate(inputs):
dfg[v] = inp.out_port(i)
dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg)
dfg = DFContainer(block)
for v, wire in zip(inputs, block.input_node, strict=True):
dfg[v] = wire
dfg = StmtCompiler(globals).compile_stmts(bb.statements, dfg)

# If we branch, we also have to compile the branch predicate
if len(bb.successors) > 1:
assert bb.branch_pred is not None
branch_port = ExprCompiler(graph, globals).compile(bb.branch_pred, dfg)
branch_port = ExprCompiler(globals).compile(bb.branch_pred, dfg)
else:
# Even if we don't branch, we still have to add a `Sum(())` predicates
branch_port = graph.add_tag(
variants=[[]], tag=0, inputs=[], parent=block
).out_port(0)
branch_port = dfg.builder.add_op(ops.Tag(0, ht.UnitSum(1)))

# Finally, we have to add the block output.
outputs: Sequence[Place]
Expand All @@ -87,7 +99,6 @@ def compile_bb(
# ordering on variables which puts linear variables at the end. The only
# exception are return vars which must be outputted in order.
branch_port = choose_vars_for_tuple_sum(
graph=graph,
unit_sum=branch_port,
output_vars=[
[
Expand All @@ -101,9 +112,7 @@ def compile_bb(
)
outputs = [v for v in first if v.ty.linear and not is_return_var(str(v))]

graph.add_output(
inputs=[branch_port] + [dfg[v] for v in sort_vars(outputs)], parent=block
)
block.set_block_outputs(branch_port, *(dfg[v] for v in sort_vars(outputs)))
return block


Expand All @@ -130,27 +139,23 @@ def insert_return_vars(cfg: CheckedCFG[Place]) -> None:


def choose_vars_for_tuple_sum(
graph: Hugr, unit_sum: OutPortV, output_vars: list[Row[Place]], dfg: DFContainer
) -> OutPortV:
unit_sum: Wire, output_vars: list[Row[Place]], dfg: DFContainer
) -> Wire:
"""Selects an output based on a TupleSum.
Given `unit_sum: Sum(*(), *(), ...)` and output variable rows `#s1, #s2, ...`,
constructs a TupleSum value of type `Sum(#s1, #s2, ...)`.
"""
assert isinstance(unit_sum.ty, SumType) or is_bool_type(unit_sum.ty)
assert len(output_vars) == (
len(unit_sum.ty.element_types) if isinstance(unit_sum.ty, SumType) else 2
)
assert all(not v.ty.linear for var_row in output_vars for v in var_row)
conditional = graph.add_conditional(cond_input=unit_sum, inputs=[], parent=dfg.node)
tys = [[v.ty for v in var_row] for var_row in output_vars]
for i, var_row in enumerate(output_vars):
case = graph.add_case(conditional)
graph.add_input(output_tys=[], parent=case)
inputs = [dfg[v] for v in var_row]
tag = graph.add_tag(variants=tys, tag=i, inputs=inputs, parent=case).out_port(0)
graph.add_output(inputs=[tag], parent=case)
return conditional.add_out_port(SumType([row_to_type(row) for row in tys]))
sum_type = SumType([row_to_type(row) for row in tys]).to_hugr()

with dfg.builder.add_conditional(unit_sum) as conditional:
for i, var_row in enumerate(output_vars):
with conditional.add_case(i) as case:
tag = case.add_op(ops.Tag(i, sum_type), *(dfg[v] for v in var_row))
case.set_outputs(tag)
return conditional


def compare_var(p1: Place, p2: Place) -> int:
Expand Down
52 changes: 31 additions & 21 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,43 @@
from abc import ABC
from dataclasses import dataclass, field
from typing import cast

from hugr import Wire, ops
from hugr.dfg import DP, _DfBase

from guppylang.checker.core import FieldAccess, Place, PlaceId, Variable
from guppylang.definition.common import CompiledDef, DefId
from guppylang.error import InternalGuppyError
from guppylang.hugr_builder.hugr import DFContainingNode, Hugr, OutPortV
from guppylang.tys.ty import StructType

CompiledGlobals = dict[DefId, CompiledDef]
CompiledLocals = dict[PlaceId, OutPortV]
CompiledLocals = dict[PlaceId, Wire]


@dataclass
class DFContainer:
"""A dataflow graph under construction.
This class is passed through the entire compilation pipeline and stores the node
whose dataflow child-graph is currently being constructed as well as all live local
This class is passed through the entire compilation pipeline and stores a builder
for the dataflow child-graph currently being constructed as well as all live local
variables. Note that the variable map is mutated in-place and always reflects the
current compilation state.
"""

graph: Hugr
node: DFContainingNode
builder: _DfBase[ops.DfParentOp]
locals: CompiledLocals = field(default_factory=dict)

def __getitem__(self, place: Place) -> OutPortV:
"""Constructs a port for a local place in this DFG.
def __init__(
self, builder: _DfBase[DP], locals: CompiledLocals | None = None
) -> None:
generic_builder = cast(_DfBase[ops.DfParentOp], builder)
if locals is None:
locals = {}
self.builder = generic_builder
self.locals = locals

def __getitem__(self, place: Place) -> Wire:
"""Constructs a wire for a local place in this DFG.
Note that this mutates the Hugr since we might need to pack or unpack some
tuples to obtain a port for places that involve struct fields.
Expand All @@ -39,23 +50,24 @@ def __getitem__(self, place: Place) -> OutPortV:
if not isinstance(place.ty, StructType):
raise InternalGuppyError(f"Couldn't obtain a port for `{place}`")
children = [FieldAccess(place, field, None) for field in place.ty.fields]
child_ports = [self[child] for child in children]
port = self.graph.add_make_tuple(child_ports, self.node).out_port(0)
child_types = [child.ty.to_hugr() for child in children]
child_wires = [self[child] for child in children]
wire = self.builder.add_op(ops.MakeTuple(child_types), *child_wires)[0]
for child in children:
if child.ty.linear:
self.locals.pop(child.id)
self.locals[place.id] = port
return port
self.locals[place.id] = wire
return wire

def __setitem__(self, place: Place, port: OutPortV) -> None:
def __setitem__(self, place: Place, port: Wire) -> None:
# When assigning a struct value, we immediately unpack it recursively and only
# store the leaf wires.
is_return = isinstance(place, Variable) and is_return_var(place.name)
if isinstance(place.ty, StructType) and not is_return:
unpack = self.graph.add_unpack_tuple(port, self.node)
for field, field_port in zip(
place.ty.fields, unpack.out_ports, strict=True
):
unpack = self.builder.add_op(
ops.UnpackTuple([t.ty.to_hugr() for t in place.ty.fields]), port
)
for field, field_port in zip(place.ty.fields, unpack, strict=True):
self[FieldAccess(place, field, None)] = field_port
# If we had a previous wire assigned to this place, we need forget about it.
# Otherwise, we might use this old value when looking up the place later
Expand All @@ -66,17 +78,15 @@ def __setitem__(self, place: Place, port: OutPortV) -> None:
def __copy__(self) -> "DFContainer":
# Make a copy of the var map so that mutating the copy doesn't
# mutate our variable mapping
return DFContainer(self.graph, self.node, self.locals.copy())
return DFContainer(self.builder, self.locals.copy())


class CompilerBase(ABC):
"""Base class for the Guppy compiler."""

graph: Hugr
globals: CompiledGlobals

def __init__(self, graph: Hugr, globals: CompiledGlobals) -> None:
self.graph = graph
def __init__(self, globals: CompiledGlobals) -> None:
self.globals = globals


Expand Down
Loading

0 comments on commit 536abf9

Please sign in to comment.