Skip to content

Commit

Permalink
feat: Compile lists and comprehensions (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Jan 16, 2024
1 parent e22fde7 commit 270ea9e
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 9 deletions.
2 changes: 1 addition & 1 deletion guppy/compiler/cfg_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def compile_bb(
for (i, v) in enumerate(inputs)
},
)
dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, bb, dfg)
dfg = StmtCompiler(graph, globals).compile_stmts(bb.statements, dfg)

# If we branch, we also have to compile the branch predicate
if len(bb.successors) > 1:
Expand Down
164 changes: 160 additions & 4 deletions guppy/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import ast
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

from guppy.ast_util import AstVisitor, get_type
from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer
from guppy.ast_util import AstVisitor, get_type, with_loc, with_type
from guppy.cfg.builder import tmp_vars
from guppy.compiler.core import (
CompiledFunction,
CompilerBase,
DFContainer,
PortVariable,
)
from guppy.error import GuppyError, InternalGuppyError
from guppy.gtypes import (
BoolType,
Expand All @@ -14,8 +22,16 @@
type_to_row,
)
from guppy.hugr import ops, val
from guppy.hugr.hugr import OutPortV
from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName, TypeApply
from guppy.hugr.hugr import DFContainingNode, OutPortV, VNode
from guppy.nodes import (
DesugaredGenerator,
DesugaredListComp,
GlobalCall,
GlobalName,
LocalCall,
LocalName,
TypeApply,
)


class ExprCompiler(CompilerBase, AstVisitor[OutPortV]):
Expand All @@ -39,6 +55,85 @@ def compile_row(self, expr: ast.expr, dfg: DFContainer) -> list[OutPortV]:
"""
return [self.compile(e, dfg) for e in expr_to_row(expr)]

@contextmanager
def _new_dfcontainer(
self, inputs: list[ast.Name], node: DFContainingNode
) -> Iterator[None]:
"""Context manager to build a graph inside a new `DFContainer`.
Automatically updates `self.dfg` and makes the inputs available.
"""
old = self.dfg
inp = self.graph.add_input(parent=node)
# Check that the input names are unique
assert len({inp.id for inp in inputs}) == len(inputs), "Inputs are not unique"
new_locals = {
name.id: PortVariable(name.id, inp.add_out_port(get_type(name)), name, None)
for name in inputs
}
self.dfg = DFContainer(node, self.dfg.locals | new_locals)
with self.graph.parent(node):
yield
self.dfg = old

@contextmanager
def _new_loop(
self,
loop_vars: list[ast.Name],
branch: ast.Name,
parent: DFContainingNode | None = None,
) -> Iterator[None]:
"""Context manager to build a graph inside a new `TailLoop` node.
Automatically adds the `Output` node to the loop body once the context manager
exits.
"""
loop = self.graph.add_tail_loop(
[self.visit(name) for name in loop_vars], parent
)
with self._new_dfcontainer(loop_vars, loop):
yield
# Output the branch predicate and the inputs for the next iteration
self.graph.add_output(
# Note that we have to do fresh calls to `self.visit` here since we're
# in a new context
[self.visit(branch), *(self.visit(name) for name in loop_vars)]
)
# Update the DFG with the outputs from the loop
for name in loop_vars:
self.dfg[name.id].port = loop.add_out_port(get_type(name))

@contextmanager
def _new_case(
self, inputs: list[ast.Name], outputs: list[ast.Name], cond_node: VNode
) -> Iterator[None]:
"""Context manager to build a graph inside a new `Case` node.
Automatically adds the `Output` node once the context manager exits.
"""
with self._new_dfcontainer(inputs, self.graph.add_case(cond_node)):
yield
self.graph.add_output([self.visit(name) for name in outputs])

@contextmanager
def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]:
"""Context manager to build a graph inside the `true` case of a `Conditional`
In the `false` case, the inputs are outputted as is.
"""
cond_node = self.graph.add_conditional(
self.visit(cond), [self.visit(inp) for inp in inputs]
)
# If the condition is false, output the inputs as is
with self._new_case(inputs, inputs, cond_node):
pass
# If the condition is true, we enter the `with` block
with self._new_case(inputs, inputs, cond_node):
yield
# Update the DFG with the outputs from the Conditional node
for name in inputs:
self.dfg[name.id].port = cond_node.add_out_port(get_type(name))

def visit_Constant(self, node: ast.Constant) -> OutPortV:
if value := python_value_to_hugr(node.value):
const = self.graph.add_constant(value, get_type(node)).out_port(0)
Expand All @@ -59,6 +154,12 @@ def visit_Tuple(self, node: ast.Tuple) -> OutPortV:
inputs=[self.visit(e) for e in node.elts]
).out_port(0)

def visit_List(self, node: ast.List) -> OutPortV:
# Note that this is a list literal (i.e. `[e1, e2, ...]`), not a comprehension
return self.graph.add_node(
ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts]
).add_out_port(get_type(node))

def _pack_returns(self, returns: list[OutPortV]) -> OutPortV:
"""Groups function return values into a tuple"""
if len(returns) != 1:
Expand Down Expand Up @@ -118,6 +219,61 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV:

raise InternalGuppyError("Node should have been removed during type checking.")

def visit_DesugaredListComp(self, node: DesugaredListComp) -> OutPortV:
from guppy.compiler.stmt_compiler import StmtCompiler

compiler = StmtCompiler(self.graph, self.globals)

# Make up a name for the list under construction and bind it to an empty list
list_ty = get_type(node)
list_name = with_type(list_ty, with_loc(node, LocalName(id=next(tmp_vars))))
empty_list = self.graph.add_node(ops.DummyOp(name="MakeList"))
self.dfg[list_name.id] = PortVariable(
list_name.id, empty_list.add_out_port(list_ty), node, None
)

def compile_generators(elt: ast.expr, gens: list[DesugaredGenerator]) -> None:
"""Helper function to generate nested TailLoop nodes for generators"""
# If there are no more generators left, just append the element to the list
if not gens:
list_port, elt_port = self.visit(list_name), self.visit(elt)
push = self.graph.add_node(
ops.DummyOp(name="Push"), inputs=[list_port, elt_port]
)
self.dfg[list_name.id].port = push.add_out_port(list_port.ty)
return

# Otherwise, compile the first iterator and construct a TailLoop
gen, *gens = gens
compiler.compile_stmts([gen.iter_assign], self.dfg)
inputs = [gen.iter, list_name]
with self._new_loop(inputs, gen.hasnext):
# If there is a next element, compile it and continue with the next
# generator
compiler.compile_stmts([gen.hasnext_assign], self.dfg)
with self._if_true(gen.hasnext, inputs):

def compile_ifs(ifs: list[ast.expr]) -> None:
"""Helper function to compile a series of if-guards into nested
Conditional nodes."""
if ifs:
if_expr, *ifs = ifs
# If the condition is true, continue with the next one
with self._if_true(if_expr, inputs):
compile_ifs(ifs)
else:
# If there are no guards left, compile the next generator
compile_generators(elt, gens)

compiler.compile_stmts([gen.next_assign], self.dfg)
compile_ifs(gen.ifs)

# After the loop is done, we have to finalize the iterator
self.visit(gen.iterend)

compile_generators(node.elt, node.generators)
return self.visit(list_name)

def visit_BinOp(self, node: ast.BinOp) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")

Expand Down
4 changes: 0 additions & 4 deletions guppy/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections.abc import Sequence

from guppy.ast_util import AstVisitor
from guppy.checker.cfg_checker import CheckedBB
from guppy.compiler.core import (
CompiledGlobals,
CompilerBase,
Expand All @@ -22,7 +21,6 @@ class StmtCompiler(CompilerBase, AstVisitor[None]):

expr_compiler: ExprCompiler

bb: CheckedBB
dfg: DFContainer

def __init__(self, graph: Hugr, globals: CompiledGlobals):
Expand All @@ -32,15 +30,13 @@ def __init__(self, graph: Hugr, globals: CompiledGlobals):
def compile_stmts(
self,
stmts: Sequence[ast.stmt],
bb: CheckedBB,
dfg: DFContainer,
) -> DFContainer:
"""Compiles a list of basic statements into a dataflow node.
Note that the `dfg` is mutated in-place. After compilation, the DFG will also
contain all variables that are assigned in the given list of statements.
"""
self.bb = bb
self.dfg = dfg
for s in stmts:
self.visit(s)
Expand Down
20 changes: 20 additions & 0 deletions guppy/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def add_output(
parent: Node | None = None,
) -> VNode:
"""Adds an `Output` node to the graph."""
parent = parent or self._default_parent
node = self.add_node(ops.Output(), input_tys, [], parent, inputs)
if isinstance(parent, DFContainingNode):
parent.output_child = node
Expand Down Expand Up @@ -706,6 +707,25 @@ def insert_order_edges(self) -> "Hugr":
elif isinstance(n.op, ops.LoadConstant):
assert n.parent.input_child is not None
self.add_order_edge(n.parent.input_child, n)

# Also add order edges for non-local edges
for src, tgt in list(self.edges()):
# Exclude CF and constant edges
if isinstance(src, OutPortCF) or isinstance(
src.node.op, ops.FuncDecl | ops.FuncDefn | ops.Const
):
continue

if src.node.parent != tgt.node.parent:
# Walk up the hierarchy from the tgt until we hit a node at the same
# level as src
node = tgt.node
while node.parent != src.node.parent:
if node.parent is None:
raise ValueError("Invalid non-local edge!")
node = node.parent
# Add order edge to make sure that the src is executed first
self.add_order_edge(src.node, node)
return self

def to_raw(self) -> raw.RawHugr:
Expand Down

0 comments on commit 270ea9e

Please sign in to comment.