diff --git a/guppylang/prelude/_internal/compiler/array.py b/guppylang/prelude/_internal/compiler/array.py index 298828ea..5c571d4e 100644 --- a/guppylang/prelude/_internal/compiler/array.py +++ b/guppylang/prelude/_internal/compiler/array.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import hugr.std from hugr import Wire, ops from hugr import tys as ht @@ -12,229 +10,193 @@ from guppylang.definition.value import CallReturnWires from guppylang.error import InternalGuppyError from guppylang.prelude._internal.compiler.arithmetic import convert_itousize -from guppylang.prelude._internal.compiler.prelude import build_error, build_panic +from guppylang.prelude._internal.compiler.prelude import ( + build_unwrap, + build_unwrap_left, + build_unwrap_right, +) from guppylang.tys.arg import ConstArg, TypeArg from guppylang.tys.const import ConstValue -if TYPE_CHECKING: - from hugr.build.dfg import DfBase +# ------------------------------------------------------ +# --------------- std.array operations ----------------- +# ------------------------------------------------------ + + +def _instantiate_array_op( + name: str, elem_ty: ht.Type, length: int, inp: list[ht.Type], out: list[ht.Type] +) -> ops.ExtOp: + return hugr.std.PRELUDE.get_op(name).instantiate( + [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)], ht.FunctionType(inp, out) + ) -def array_type(length: int, elem_ty: ht.Type) -> ht.ExtType: +def array_type(elem_ty: ht.Type, length: int) -> ht.ExtType: """Returns the hugr type of a fixed length array.""" length_arg = ht.BoundedNatArg(length) elem_arg = ht.TypeTypeArg(elem_ty) return hugr.std.PRELUDE.types["array"].instantiate([length_arg, elem_arg]) +def array_new(elem_ty: ht.Type, length: int) -> ops.ExtOp: + """Returns an operation that creates a new fixed length array.""" + arr_ty = array_type(elem_ty, length) + return _instantiate_array_op( + "new_array", elem_ty, length, [elem_ty] * length, [arr_ty] + ) + + +def array_get(elem_ty: ht.Type, length: int) -> ops.ExtOp: + """Returns an array `get` operation.""" + assert elem_ty.type_bound() == ht.TypeBound.Copyable + arr_ty = array_type(elem_ty, length) + return _instantiate_array_op( + "get", elem_ty, length, [arr_ty, ht.USize()], [ht.Option(elem_ty)] + ) + + +def array_set(elem_ty: ht.Type, length: int) -> ops.ExtOp: + """Returns an array `set` operation.""" + arr_ty = array_type(elem_ty, length) + return _instantiate_array_op( + "set", + elem_ty, + length, + [arr_ty, ht.USize(), elem_ty], + [ht.Either([elem_ty, arr_ty], [elem_ty, arr_ty])], + ) + + # ------------------------------------------------------ # --------- Custom compilers for non-native ops -------- # ------------------------------------------------------ -class NewArrayCompiler(CustomCallCompiler): +class ArrayCompiler(CustomCallCompiler): + """Base class for custom array op compilers.""" + + @property + def elem_ty(self) -> ht.Type: + """The element type for the array op that is being compiled.""" + match self.type_args: + case [TypeArg(ty=elem_ty), _]: + return elem_ty.to_hugr() + case _: + raise InternalGuppyError("Invalid array type args") + + @property + def length(self) -> int: + """The length for the array op that is being compiled.""" + match self.type_args: + case [_, ConstArg(ConstValue(value=int(length)))]: + return length + case _: + raise InternalGuppyError("Invalid array type args") + + +class NewArrayCompiler(ArrayCompiler): """Compiler for the `array.__new__` function.""" + def build_classical_array(self, elems: list[Wire]) -> Wire: + """Lowers a call to `array.__new__` for classical arrays.""" + return self.builder.add_op(array_new(self.elem_ty, len(elems)), *elems) + + def build_linear_array(self, elems: list[Wire]) -> Wire: + """Lowers a call to `array.__new__` for linear arrays.""" + elem_opt_ty = ht.Option(self.elem_ty) + elem_opts = [ + self.builder.add_op(ops.Tag(1, elem_opt_ty), elem) for elem in elems + ] + return self.builder.add_op(array_new(elem_opt_ty, len(elems)), *elem_opts) + def compile(self, args: list[Wire]) -> list[Wire]: - match self.type_args: - case [TypeArg(ty=elem_ty), ConstArg(ConstValue(value=int(length)))]: - op = new_array(length, elem_ty.to_hugr()) - return [self.builder.add_op(op, *args)] - case type_args: - raise InternalGuppyError(f"Invalid array type args: {type_args}") + if self.elem_ty.type_bound() == ht.TypeBound.Any: + return [self.build_linear_array(args)] + else: + return [self.build_classical_array(args)] -class ArrayGetitemCompiler(CustomCallCompiler): +class ArrayGetitemCompiler(ArrayCompiler): """Compiler for the `array.__getitem__` function.""" - def build_classical_getitem( - self, - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem_ty: ht.Type, - ) -> CallReturnWires: + def build_classical_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: """Lowers a call to `array.__getitem__` for classical arrays.""" - length = self.type_args[1].to_hugr() - elem = build_array_get( - self.builder, array, array_ty, idx, idx_ty, elem_ty, length - ) + idx = self.builder.add_op(convert_itousize(), idx) + result = self.builder.add_op(array_get(self.elem_ty, self.length), array, idx) + elem = build_unwrap(self.builder, result, "Array index out of bounds") return CallReturnWires(regular_returns=[elem], inout_returns=[array]) - def build_linear_getitem( - self, - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem_ty: ht.Type, - ) -> CallReturnWires: + def build_linear_getitem(self, array: Wire, idx: Wire) -> CallReturnWires: """Lowers a call to `array.__getitem__` for linear arrays.""" # Swap out the element at the given index with `None`. The `to_hugr` # implementation of the array type ensures that linear element types are turned # into optionals. - elem_opt_ty = ht.Sum([[elem_ty], []]) - none = self.builder.add_op(ops.Tag(1, elem_opt_ty)) - length = self.type_args[1].to_hugr() - array, elem_opt = build_array_set( - self.builder, - array, - array_ty, - idx, - idx_ty, - none, - elem_opt_ty, - length, + elem_opt_ty = ht.Option(self.elem_ty) + none = self.builder.add_op(ops.Tag(0, elem_opt_ty)) + idx = self.builder.add_op(convert_itousize(), idx) + result = self.builder.add_op( + array_set(elem_opt_ty, self.length), array, idx, none + ) + elem_opt, array = build_unwrap_right( + self.builder, result, "Array index out of bounds" + ) + elem = build_unwrap( + self.builder, elem_opt, "Linear array element has already been used" ) - # Make sure that the element we got out is not None - conditional = self.builder.add_conditional(elem_opt) - with conditional.add_case(0) as case: - case.set_outputs(*case.inputs()) - with conditional.add_case(1) as case: - error = build_error(case, 1, "Linear array element has already been used") - case.set_outputs(*build_panic(case, [], [elem_ty], error)) - return CallReturnWires(regular_returns=[conditional], inout_returns=[array]) + return CallReturnWires(regular_returns=[elem], inout_returns=[array]) def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: [array, idx] = args - [array_ty, idx_ty] = self.ty.input - [elem_ty, *_] = self.ty.output - if elem_ty.type_bound() == ht.TypeBound.Any: - return self.build_linear_getitem(array, array_ty, idx, idx_ty, elem_ty) + if self.elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_getitem(array, idx) else: - return self.build_classical_getitem(array, array_ty, idx, idx_ty, elem_ty) + return self.build_classical_getitem(array, idx) def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Call compile_with_inouts instead") -class ArraySetitemCompiler(CustomCallCompiler): +class ArraySetitemCompiler(ArrayCompiler): """Compiler for the `array.__setitem__` function.""" def build_classical_setitem( - self, - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem: Wire, - elem_ty: ht.Type, - length: ht.TypeArg, + self, array: Wire, idx: Wire, elem: Wire ) -> CallReturnWires: """Lowers a call to `array.__setitem__` for classical arrays.""" - array, _ = build_array_set( - self.builder, array, array_ty, idx, idx_ty, elem, elem_ty, length + idx = self.builder.add_op(convert_itousize(), idx) + result = self.builder.add_op( + array_set(self.elem_ty, self.length), array, idx, elem ) + # Unwrap the result, but we don't have to hold onto the returned old value + _, array = build_unwrap_right(self.builder, result, "Array index out of bounds") return CallReturnWires(regular_returns=[], inout_returns=[array]) def build_linear_setitem( - self, - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem: Wire, - elem_ty: ht.Type, - length: ht.TypeArg, + self, array: Wire, idx: Wire, elem: Wire ) -> CallReturnWires: """Lowers a call to `array.__setitem__` for linear arrays.""" # Embed the element into an optional - elem_opt_ty = ht.Sum([[elem_ty], []]) - elem = self.builder.add_op(ops.Tag(0, elem_opt_ty), elem) - array, old_elem = build_array_set( - self.builder, array, array_ty, idx, idx_ty, elem, elem_opt_ty, length + elem_opt_ty = ht.Option(self.elem_ty) + elem = self.builder.add_op(ops.Tag(1, elem_opt_ty), elem) + idx = self.builder.add_op(convert_itousize(), idx) + result = self.builder.add_op( + array_set(elem_opt_ty, self.length), array, idx, elem + ) + old_elem_opt, array = build_unwrap_right( + self.builder, result, "Array index out of bounds" ) # Check that the old element was `None` - conditional = self.builder.add_conditional(old_elem) - with conditional.add_case(0) as case: - error = build_error(case, 1, "Linear array element has not been used") - build_panic(case, [elem_ty], [], error, *case.inputs()) - case.set_outputs() - with conditional.add_case(1) as case: - case.set_outputs() + build_unwrap_left( + self.builder, old_elem_opt, "Linear array element has not been used" + ) return CallReturnWires(regular_returns=[], inout_returns=[array]) def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: [array, idx, elem] = args - [array_ty, idx_ty, elem_ty] = self.ty.input - length = self.type_args[1].to_hugr() - if elem_ty.type_bound() == ht.TypeBound.Any: - return self.build_linear_setitem( - array, array_ty, idx, idx_ty, elem, elem_ty, length - ) + if self.elem_ty.type_bound() == ht.TypeBound.Any: + return self.build_linear_setitem(array, idx, elem) else: - return self.build_classical_setitem( - array, array_ty, idx, idx_ty, elem, elem_ty, length - ) + return self.build_classical_setitem(array, idx, elem) def compile(self, args: list[Wire]) -> list[Wire]: raise InternalGuppyError("Call compile_with_inouts instead") - - -def build_array_set( - builder: DfBase[ops.DfParentOp], - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem: Wire, - elem_ty: ht.Type, - length: ht.TypeArg, -) -> tuple[Wire, Wire]: - """Builds an array set operation, returning the original element.""" - sig = ht.FunctionType( - [array_ty, ht.USize(), elem_ty], - [ht.Sum([[elem_ty, array_ty], [elem_ty, array_ty]])], - ) - if idx_ty != ht.USize(): - idx = builder.add_op(convert_itousize(), idx) - op = ops.ExtOp( - hugr.std.PRELUDE.get_op("set"), sig, [length, ht.TypeTypeArg(elem_ty)] - ) - [result] = builder.add_op(op, array, idx, elem) - conditional = builder.add_conditional(result) - with conditional.add_case(0) as case: - error = build_error(case, 1, "array set index out of bounds") - case.set_outputs( - *build_panic( - case, [elem_ty, array_ty], [elem_ty, array_ty], error, *case.inputs() - ) - ) - with conditional.add_case(1) as case: - case.set_outputs(*case.inputs()) - [elem, array] = conditional - return (array, elem) - - -def build_array_get( - builder: DfBase[ops.DfParentOp], - array: Wire, - array_ty: ht.Type, - idx: Wire, - idx_ty: ht.Type, - elem_ty: ht.Type, - length: ht.TypeArg, -) -> Wire: - """Builds an array get operation, returning the original element.""" - sig = ht.FunctionType([array_ty, ht.USize()], [ht.Sum([[], [elem_ty]])]) - op = ops.ExtOp( - hugr.std.PRELUDE.get_op("get"), sig, [length, ht.TypeTypeArg(elem_ty)] - ) - if idx_ty != ht.USize(): - idx = builder.add_op(convert_itousize(), idx) - [result] = builder.add_op(op, array, idx) - conditional = builder.add_conditional(result) - with conditional.add_case(0) as case: - error = build_error(case, 1, "array get index out of bounds") - case.set_outputs(*build_panic(case, [], [elem_ty], error)) - with conditional.add_case(1) as case: - case.set_outputs(*case.inputs()) - return conditional - - -def new_array(length: int, elem_ty: ht.Type) -> ops.ExtOp: - """Returns an operation that creates a new fixed length array.""" - op_def = hugr.std.PRELUDE.get_op("new_array") - sig = ht.FunctionType([elem_ty] * length, [array_type(length, elem_ty)]) - return ops.ExtOp(op_def, sig, [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)]) diff --git a/guppylang/prelude/_internal/compiler/list.py b/guppylang/prelude/_internal/compiler/list.py index c72f4721..72872d67 100644 --- a/guppylang/prelude/_internal/compiler/list.py +++ b/guppylang/prelude/_internal/compiler/list.py @@ -19,9 +19,8 @@ convert_itousize, ) from guppylang.prelude._internal.compiler.prelude import ( - build_error, - build_panic, build_unwrap, + build_unwrap_left, build_unwrap_right, ) from guppylang.tys.arg import TypeArg @@ -191,18 +190,9 @@ def build_linear_setitem( self.builder, result, "List index out of bounds" ) # Check that the old element was `None` - conditional = self.builder.add_conditional(old_elem_opt, list_wire) - with conditional.add_case(0) as case: - case.set_outputs(*case.inputs()) - with conditional.add_case(1) as case: - # Note: This case can only happen if users manually call `xs.__setitem__` - # since regular indexing `xs[i]` is only allowed in inout position. An error - # in that situation would be a compiler bug! - old_elem, list_wire = case.inputs() - error = build_error(case, 1, "Linear list element has not been used") - build_panic(case, [elem_ty], [], error, old_elem) - case.set_outputs(list_wire) - (list_wire,) = conditional.outputs() + build_unwrap_left( + self.builder, old_elem_opt, "Linear list element has not been used" + ) return CallReturnWires(regular_returns=[], inout_returns=[list_wire]) def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: diff --git a/guppylang/prelude/_internal/compiler/prelude.py b/guppylang/prelude/_internal/compiler/prelude.py index ebcebf2c..e5dcfbba 100644 --- a/guppylang/prelude/_internal/compiler/prelude.py +++ b/guppylang/prelude/_internal/compiler/prelude.py @@ -76,6 +76,10 @@ def build_error(builder: DfBase[ops.Case], signal: int, msg: str) -> Wire: return builder.load(builder.add_const(val)) +# TODO: Common up build_unwrap_right and build_unwrap_left below once +# https://github.com/CQCL/hugr/issues/1596 is fixed + + def build_unwrap_right( builder: DfBase[ops.DfParentOp], either: Wire, error_msg: str, error_signal: int = 1 ) -> Node: @@ -94,6 +98,24 @@ def build_unwrap_right( return conditional.to_node() +def build_unwrap_left( + builder: DfBase[ops.DfParentOp], either: Wire, error_msg: str, error_signal: int = 1 +) -> Node: + """Unwraps the left value from a `hugr.tys.Either` value, panicking with the given + message if the result is right. + """ + conditional = builder.add_conditional(either) + result_ty = builder.hugr.port_type(either.out_port()) + assert isinstance(result_ty, ht.Sum) + [left_tys, right_tys] = result_ty.variant_rows + with conditional.add_case(0) as case: + case.set_outputs(*case.inputs()) + with conditional.add_case(1) as case: + error = build_error(case, error_signal, error_msg) + case.set_outputs(*build_panic(case, right_tys, left_tys, error, *case.inputs())) + return conditional.to_node() + + def build_unwrap( builder: DfBase[ops.DfParentOp], result: Wire, error_msg: str, error_signal: int = 1 ) -> Node: diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 1bf305a0..1634ec14 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -132,11 +132,9 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: # Linear elements are turned into an optional to enable unsafe indexing. # See `ArrayGetitemCompiler` for details. - elem_ty: ht.Type - if ty_arg.ty.linear: - elem_ty = ht.Sum([[ty_arg.ty.to_hugr()], []]) - else: - elem_ty = ty_arg.ty.to_hugr() + elem_ty = ( + ht.Option(ty_arg.ty.to_hugr()) if ty_arg.ty.linear else ty_arg.ty.to_hugr() + ) array = hugr.std.PRELUDE.get_type("array") return array.instantiate([len_arg.to_hugr(), ht.TypeTypeArg(elem_ty)]) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index cedbb1c8..d71f5b2a 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -79,6 +79,18 @@ def main(ys: array[int, 0]) -> array[array[int, 0], 2]: validate(main) +def test_return_linear_array(validate): + module = GuppyModule("test") + module.load(qubit) + + @guppy(module) + def foo() -> array[qubit, 2]: + a = array(qubit(), qubit()) + return a + + validate(module.compile()) + + def test_subscript_drop_rest(validate): module = GuppyModule("test") module.load_all(quantum)