Skip to content

Commit

Permalink
feat: Lower unpacking assignment of iterables to Hugr
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Dec 4, 2024
1 parent 31717af commit 5730408
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 28 deletions.
100 changes: 87 additions & 13 deletions guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
import functools
from collections.abc import Sequence

import hugr.tys as ht
from hugr import Wire, ops
from hugr.build.dfg import DfBase

Expand All @@ -14,8 +16,21 @@
)
from guppylang.compiler.expr_compiler import ExprCompiler
from guppylang.error import InternalGuppyError
from guppylang.nodes import CheckedNestedFunctionDef, PlaceNode, TupleUnpack
from guppylang.tys.ty import TupleType, Type
from guppylang.nodes import (
CheckedNestedFunctionDef,
IterableUnpack,
PlaceNode,
TupleUnpack,
)
from guppylang.std._internal.compiler.array import (
array_discard_empty,
array_new,
array_pop,
)
from guppylang.std._internal.compiler.prelude import build_unwrap
from guppylang.tys.builtin import get_element_type
from guppylang.tys.const import ConstValue
from guppylang.tys.ty import TupleType, Type, type_to_row


class StmtCompiler(CompilerBase, AstVisitor[None]):
Expand Down Expand Up @@ -49,27 +64,86 @@ def builder(self) -> DfBase[ops.DfParentOp]:
"""The Hugr dataflow graph builder."""
return self.dfg.builder

def _unpack_assign(self, lhs: ast.expr, port: Wire, node: ast.stmt) -> None:
@functools.singledispatchmethod
def _assign(self, lhs: ast.expr, port: Wire) -> None:
"""Updates the local DFG with assignments."""
if isinstance(lhs, PlaceNode):
self.dfg[lhs.place] = port
elif isinstance(lhs, TupleUnpack):
types = [get_type(e).to_hugr() for e in lhs.pattern.left]
unpack = self.builder.add_op(ops.UnpackTuple(types), port)
for pat, wire in zip(lhs.pattern.left, unpack, strict=True):
self._unpack_assign(pat, wire, node)
raise InternalGuppyError("Invalid assign pattern in compiler")

@_assign.register
def _assign_place(self, lhs: PlaceNode, port: Wire) -> None:
self.dfg[lhs.place] = port

@_assign.register
def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None:
"""Handles assignment where the RHS is a tuple that should be unpacked."""
# Unpack the RHS tuple
left, starred, right = lhs.pattern.left, lhs.pattern.starred, lhs.pattern.right
types = [ty.to_hugr() for ty in type_to_row(get_type(lhs))]
unpack = self.builder.add_op(ops.UnpackTuple(types), port)
ports = list(unpack)

# Assign left and right
for pat, wire in zip(left, ports[: len(left)], strict=True):
self._assign(pat, wire)
if right:
for pat, wire in zip(right, ports[-len(right) :], strict=True):
self._assign(pat, wire)

# Starred assignments are collected into an array
if starred:
array_ty = get_type(starred)
starred_ports = (
ports[len(left) : -len(right)] if right else ports[len(left) :]
)
opt_ty = ht.Option(get_element_type(array_ty).to_hugr())
opts = [self.builder.add_op(ops.Tag(1, opt_ty), p) for p in starred_ports]
array = self.builder.add_op(array_new(opt_ty, len(opts)), *opts)
self._assign(starred, array)

@_assign.register
def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None:
"""Handles assignment where the RHS is an iterable that should be unpacked."""
# Given an assignment pattern `left, *starred, right`, collect the RHS into an
# array and pop from the left and right, leaving us with the starred array in
# the middle
assert isinstance(lhs.compr.length, ConstValue)
length = lhs.compr.length.value
assert isinstance(length, int)
opt_elt_ty = ht.Option(lhs.compr.elt_ty.to_hugr())

def pop(
array: Wire, length: int, pats: list[ast.expr], from_left: bool
) -> tuple[Wire, int]:
err = "Internal error: unpacking of iterable failed"
for pat in pats:
res = self.builder.add_op(
array_pop(opt_elt_ty, length, from_left), array
)
[elt_opt, array] = build_unwrap(self.builder, res, err)
[elt] = build_unwrap(self.builder, elt_opt, err)
self._assign(pat, elt)
length -= 1
return array, length

self.dfg[lhs.rhs_var.place] = port
array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr)
array, length = pop(array, length, lhs.pattern.left, True)
array, length = pop(array, length, lhs.pattern.right, False)
if lhs.pattern.starred:
self._assign(lhs.pattern.starred, array)
else:
raise InternalGuppyError("Invalid assign pattern in compiler")
assert length == 0
self.builder.add_op(array_discard_empty(opt_elt_ty), array)

def visit_Assign(self, node: ast.Assign) -> None:
[target] = node.targets
port = self.expr_compiler.compile(node.value, self.dfg)
self._unpack_assign(target, port, node)
self._assign(target, port)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
assert node.value is not None
port = self.expr_compiler.compile(node.value, self.dfg)
self._unpack_assign(node.target, port, node)
self._assign(node.target, port)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down
20 changes: 20 additions & 0 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
)


def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp:
"""Returns an operation that pops an element from the left of an array."""
assert length > 0
length_arg = ht.BoundedNatArg(length)
arr_ty = array_type(elem_ty, length_arg)
popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1))
op = "pop_left" if from_left else "pop_right"
return _instantiate_array_op(
op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)]
)


def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that discards an array of length zero."""
arr_ty = array_type(elem_ty, ht.BoundedNatArg(0))
return hugr.std.PRELUDE.get_op("discard_empty").instantiate(
[ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], [])
)


def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that maps a function across an array."""
# TODO
Expand Down
23 changes: 8 additions & 15 deletions tests/integration/test_unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def main(qs: array[qubit, 3] @ owned) -> tuple[qubit, qubit, qubit]:
q1, q2, q3 = qs
return q1, q2, q3

# validate(module.compile())
module.check()
validate(module.compile())


def test_unpack_starred(validate):
Expand All @@ -32,8 +31,7 @@ def main(
[*qs] = qs
return q1, q2, q3, q4, q5, q6, qs

# validate(module.compile())
module.check()
validate(module.compile())


def test_unpack_starred_empty(validate):
Expand All @@ -45,8 +43,7 @@ def main(qs: array[qubit, 2] @ owned) -> tuple[qubit, array[qubit, 0], qubit]:
q1, *empty, q2 = qs
return q1, empty, q2

# validate(module.compile())
module.check()
validate(module.compile())


def test_unpack_big_iterable(validate):
Expand All @@ -59,8 +56,7 @@ def main(qs: array[qubit, 1000] @ owned) -> tuple[qubit, array[qubit, 998], qubi
q1, *qs, q2 = qs
return q1, qs, q2

# validate(module.compile())
module.check()
validate(module.compile())


def test_unpack_range(validate, run_int_fn):
Expand All @@ -71,9 +67,8 @@ def main() -> int:
[_, x, *_, y, _] = range(10)
return x + y

module.check()
# compiled = module.compile()
# validate(compiled)
compiled = module.compile()
validate(compiled)
# TODO: Enable execution test once array lowering is fully supported
# run_int_fn(compiled, expected=9)

Expand All @@ -86,8 +81,7 @@ def main() -> array[int, 2]:
x, *ys, z = 1, 2, 3, 4
return ys

# validate(module.compile())
module.check()
validate(module.compile())


def test_unpack_nested(validate, run_int_fn):
Expand All @@ -107,5 +101,4 @@ def main(
(x, [y, *z, _], *a), *b, c = xs
return x, y, z, a, b, c

# validate(module.compile())
module.check()
validate(module.compile())

0 comments on commit 5730408

Please sign in to comment.