diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index 1d070407..6f92854e 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -1,6 +1,8 @@ import ast +import functools from collections.abc import Sequence +import hugr.tys as ht from hugr import Wire, ops from hugr.build.dfg import DfBase @@ -14,8 +16,21 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.error import InternalGuppyError -from guppylang.nodes import CheckedNestedFunctionDef, PlaceNode, TupleUnpack -from guppylang.tys.ty import TupleType, Type +from guppylang.nodes import ( + CheckedNestedFunctionDef, + IterableUnpack, + PlaceNode, + TupleUnpack, +) +from guppylang.std._internal.compiler.array import ( + array_discard_empty, + array_new, + array_pop, +) +from guppylang.std._internal.compiler.prelude import build_unwrap +from guppylang.tys.builtin import get_element_type +from guppylang.tys.const import ConstValue +from guppylang.tys.ty import TupleType, Type, type_to_row class StmtCompiler(CompilerBase, AstVisitor[None]): @@ -49,27 +64,86 @@ def builder(self) -> DfBase[ops.DfParentOp]: """The Hugr dataflow graph builder.""" return self.dfg.builder - def _unpack_assign(self, lhs: ast.expr, port: Wire, node: ast.stmt) -> None: + @functools.singledispatchmethod + def _assign(self, lhs: ast.expr, port: Wire) -> None: """Updates the local DFG with assignments.""" - if isinstance(lhs, PlaceNode): - self.dfg[lhs.place] = port - elif isinstance(lhs, TupleUnpack): - types = [get_type(e).to_hugr() for e in lhs.pattern.left] - unpack = self.builder.add_op(ops.UnpackTuple(types), port) - for pat, wire in zip(lhs.pattern.left, unpack, strict=True): - self._unpack_assign(pat, wire, node) + raise InternalGuppyError("Invalid assign pattern in compiler") + + @_assign.register + def _assign_place(self, lhs: PlaceNode, port: Wire) -> None: + self.dfg[lhs.place] = port + + @_assign.register + def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None: + """Handles assignment where the RHS is a tuple that should be unpacked.""" + # Unpack the RHS tuple + left, starred, right = lhs.pattern.left, lhs.pattern.starred, lhs.pattern.right + types = [ty.to_hugr() for ty in type_to_row(get_type(lhs))] + unpack = self.builder.add_op(ops.UnpackTuple(types), port) + ports = list(unpack) + + # Assign left and right + for pat, wire in zip(left, ports[: len(left)], strict=True): + self._assign(pat, wire) + if right: + for pat, wire in zip(right, ports[-len(right) :], strict=True): + self._assign(pat, wire) + + # Starred assignments are collected into an array + if starred: + array_ty = get_type(starred) + starred_ports = ( + ports[len(left) : -len(right)] if right else ports[len(left) :] + ) + opt_ty = ht.Option(get_element_type(array_ty).to_hugr()) + opts = [self.builder.add_op(ops.Tag(1, opt_ty), p) for p in starred_ports] + array = self.builder.add_op(array_new(opt_ty, len(opts)), *opts) + self._assign(starred, array) + + @_assign.register + def _assign_iterable(self, lhs: IterableUnpack, port: Wire) -> None: + """Handles assignment where the RHS is an iterable that should be unpacked.""" + # Given an assignment pattern `left, *starred, right`, collect the RHS into an + # array and pop from the left and right, leaving us with the starred array in + # the middle + assert isinstance(lhs.compr.length, ConstValue) + length = lhs.compr.length.value + assert isinstance(length, int) + opt_elt_ty = ht.Option(lhs.compr.elt_ty.to_hugr()) + + def pop( + array: Wire, length: int, pats: list[ast.expr], from_left: bool + ) -> tuple[Wire, int]: + err = "Internal error: unpacking of iterable failed" + for pat in pats: + res = self.builder.add_op( + array_pop(opt_elt_ty, length, from_left), array + ) + [elt_opt, array] = build_unwrap(self.builder, res, err) + [elt] = build_unwrap(self.builder, elt_opt, err) + self._assign(pat, elt) + length -= 1 + return array, length + + self.dfg[lhs.rhs_var.place] = port + array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr) + array, length = pop(array, length, lhs.pattern.left, True) + array, length = pop(array, length, lhs.pattern.right, False) + if lhs.pattern.starred: + self._assign(lhs.pattern.starred, array) else: - raise InternalGuppyError("Invalid assign pattern in compiler") + assert length == 0 + self.builder.add_op(array_discard_empty(opt_elt_ty), array) def visit_Assign(self, node: ast.Assign) -> None: [target] = node.targets port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(target, port, node) + self._assign(target, port) def visit_AnnAssign(self, node: ast.AnnAssign) -> None: assert node.value is not None port = self.expr_compiler.compile(node.value, self.dfg) - self._unpack_assign(node.target, port, node) + self._assign(node.target, port) def visit_AugAssign(self, node: ast.AugAssign) -> None: raise InternalGuppyError("Node should have been removed during type checking.") diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index b6d90b12..3af17705 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -72,6 +72,26 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) +def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp: + """Returns an operation that pops an element from the left of an array.""" + assert length > 0 + length_arg = ht.BoundedNatArg(length) + arr_ty = array_type(elem_ty, length_arg) + popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1)) + op = "pop_left" if from_left else "pop_right" + return _instantiate_array_op( + op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)] + ) + + +def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp: + """Returns an operation that discards an array of length zero.""" + arr_ty = array_type(elem_ty, ht.BoundedNatArg(0)) + return hugr.std.PRELUDE.get_op("discard_empty").instantiate( + [ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], []) + ) + + def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: """Returns an operation that maps a function across an array.""" # TODO diff --git a/tests/integration/test_unpack.py b/tests/integration/test_unpack.py index c9cbe991..f399f22d 100644 --- a/tests/integration/test_unpack.py +++ b/tests/integration/test_unpack.py @@ -14,8 +14,7 @@ def main(qs: array[qubit, 3] @ owned) -> tuple[qubit, qubit, qubit]: q1, q2, q3 = qs return q1, q2, q3 - # validate(module.compile()) - module.check() + validate(module.compile()) def test_unpack_starred(validate): @@ -32,8 +31,7 @@ def main( [*qs] = qs return q1, q2, q3, q4, q5, q6, qs - # validate(module.compile()) - module.check() + validate(module.compile()) def test_unpack_starred_empty(validate): @@ -45,8 +43,7 @@ def main(qs: array[qubit, 2] @ owned) -> tuple[qubit, array[qubit, 0], qubit]: q1, *empty, q2 = qs return q1, empty, q2 - # validate(module.compile()) - module.check() + validate(module.compile()) def test_unpack_big_iterable(validate): @@ -59,8 +56,7 @@ def main(qs: array[qubit, 1000] @ owned) -> tuple[qubit, array[qubit, 998], qubi q1, *qs, q2 = qs return q1, qs, q2 - # validate(module.compile()) - module.check() + validate(module.compile()) def test_unpack_range(validate, run_int_fn): @@ -71,9 +67,8 @@ def main() -> int: [_, x, *_, y, _] = range(10) return x + y - module.check() - # compiled = module.compile() - # validate(compiled) + compiled = module.compile() + validate(compiled) # TODO: Enable execution test once array lowering is fully supported # run_int_fn(compiled, expected=9) @@ -86,8 +81,7 @@ def main() -> array[int, 2]: x, *ys, z = 1, 2, 3, 4 return ys - # validate(module.compile()) - module.check() + validate(module.compile()) def test_unpack_nested(validate, run_int_fn): @@ -107,5 +101,4 @@ def main( (x, [y, *z, _], *a), *b, c = xs return x, y, z, a, b, c - # validate(module.compile()) - module.check() + validate(module.compile())