diff --git a/guppylang/std/quantum.py b/guppylang/std/quantum.py index d937da61..9925648a 100644 --- a/guppylang/std/quantum.py +++ b/guppylang/std/quantum.py @@ -13,7 +13,7 @@ ) from guppylang.std._internal.util import quantum_op from guppylang.std.angles import angle -from guppylang.std.builtins import owned +from guppylang.std.builtins import array, owned from guppylang.std.option import Option @@ -144,3 +144,21 @@ def measure(q: qubit @ owned) -> bool: ... @guppy.hugr_op(quantum_op("Reset")) @no_type_check def reset(q: qubit) -> None: ... + + +N = guppy.nat_var("N") + + +@guppy +@no_type_check +def measure_array(qubits: array[qubit, N] @ owned) -> array[bool, N]: + """Measure an array of qubits, returning an array of bools.""" + return array(measure(q) for q in qubits) + + +@guppy +@no_type_check +def discard_array(qubits: array[qubit, N] @ owned) -> None: + """Discard an array of qubits.""" + for q in qubits: + discard(q) diff --git a/tests/integration/test_quantum.py b/tests/integration/test_quantum.py index b94efef0..17e77ace 100644 --- a/tests/integration/test_quantum.py +++ b/tests/integration/test_quantum.py @@ -6,9 +6,16 @@ from guppylang.module import GuppyModule from guppylang.std.angles import angle -from guppylang.std.builtins import owned - -from guppylang.std.quantum import discard, measure, qubit, maybe_qubit +from guppylang.std.builtins import owned, array + +from guppylang.std.quantum import ( + discard, + measure, + qubit, + maybe_qubit, + measure_array, + discard_array, +) from guppylang.std.quantum_functional import ( cx, cy, @@ -43,7 +50,9 @@ def compile_quantum_guppy(fn) -> ModulePointer: ), "`@compile_quantum_guppy` does not support extra arguments." module = GuppyModule("module") - module.load(angle, qubit, discard, measure, maybe_qubit) + module.load( + angle, qubit, discard, measure, measure_array, maybe_qubit, discard_array + ) module.load_all(quantum_functional) guppylang.decorator.guppy(module)(fn) return module.compile() @@ -123,3 +132,25 @@ def test( q2 = rz(q2, a3) q1, q2 = crz(q1, q2, a3) return (q1, q2) + + +def test_measure_array(validate): + """Build and measure array.""" + + @compile_quantum_guppy + def test() -> array[bool, 10]: + qs = array(qubit() for _ in range(10)) + return measure_array(qs) + + validate(test) + + +def test_discard_array(validate): + """Build and discard array.""" + + @compile_quantum_guppy + def test() -> None: + qs = array(qubit() for _ in range(10)) + discard_array(qs) + + validate(test)