Skip to content

Commit

Permalink
feat: add qubit discard/measure methods (#580)
Browse files Browse the repository at this point in the history
Extends the fact that allocation is tied to the `qubit` type by
implementing measure/discard methods.

Note this includes a drive-by `measure_reset` operation which avoids
resetting the qubit and just flip, might not be something to add in this
PR.
  • Loading branch information
ss2165 authored Oct 23, 2024
1 parent 1d29d39 commit 242fa44
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 18 deletions.
47 changes: 47 additions & 0 deletions guppylang/prelude/quantum.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,84 +26,128 @@ def __new__() -> "qubit":
reset(q)
return q

@guppy
@no_type_check
def measure(self: "qubit" @ owned) -> bool:
return measure(self)

@guppy
@no_type_check
def measure_return(self: "qubit") -> bool:
return measure_return(self)

@guppy
@no_type_check
def measure_reset(self: "qubit") -> bool:
"""Projective measure and reset without discarding the qubit."""
res = self.measure_return()
if res:
x(self)
return res

@guppy
@no_type_check
def discard(self: "qubit" @ owned) -> None:
discard(self)


@guppy.hugr_op(quantum_op("H"))
@no_type_check
def h(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CZ"))
@no_type_check
def cz(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CY"))
@no_type_check
def cy(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("CX"))
@no_type_check
def cx(control: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("T"))
@no_type_check
def t(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("S"))
@no_type_check
def s(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("X"))
@no_type_check
def x(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Y"))
@no_type_check
def y(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Z"))
@no_type_check
def z(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Tdg"))
@no_type_check
def tdg(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("Sdg"))
@no_type_check
def sdg(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("ZZMax", ext=HSERIES_EXTENSION))
@no_type_check
def zz_max(q1: qubit, q2: qubit) -> None: ...


@guppy.custom(RotationCompiler("Rz"))
@no_type_check
def rz(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("Rx"))
@no_type_check
def rx(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("Ry"))
@no_type_check
def ry(q: qubit, angle: angle) -> None: ...


@guppy.custom(RotationCompiler("CRz"))
@no_type_check
def crz(control: qubit, target: qubit, angle: angle) -> None: ...


@guppy.hugr_op(quantum_op("Toffoli"))
@no_type_check
def toffoli(control1: qubit, control2: qubit, target: qubit) -> None: ...


@guppy.hugr_op(quantum_op("QAlloc"))
@no_type_check
def dirty_qubit() -> qubit: ...


@guppy.custom(MeasureReturnCompiler())
@no_type_check
def measure_return(q: qubit) -> bool: ...


@guppy.hugr_op(quantum_op("QFree"))
@no_type_check
def discard(q: qubit @ owned) -> None: ...


Expand Down Expand Up @@ -131,6 +175,7 @@ def zz_phase(q1: qubit, q2: qubit, angle: angle) -> None:


@guppy.hugr_op(quantum_op("Reset"))
@no_type_check
def reset(q: qubit) -> None: ...


Expand All @@ -140,6 +185,7 @@ def reset(q: qubit) -> None: ...


@guppy.hugr_op(quantum_op("PhasedX", ext=HSERIES_EXTENSION))
@no_type_check
def _phased_x(q: qubit, angle1: float, angle2: float) -> None:
"""PhasedX operation from the hseries extension.
Expand All @@ -149,6 +195,7 @@ def _phased_x(q: qubit, angle1: float, angle2: float) -> None:


@guppy.hugr_op(quantum_op("ZZPhase", ext=HSERIES_EXTENSION))
@no_type_check
def _zz_phase(q1: qubit, q2: qubit, angle: float) -> None:
"""ZZPhase operation from the hseries extension.
Expand Down
59 changes: 41 additions & 18 deletions tests/integration/test_inout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_basic(validate):
def foo(q: qubit) -> None: ...

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
foo(q)
return q

Expand All @@ -29,10 +29,10 @@ def test_mixed(validate):
module.load_all(quantum)

@guppy.declare(module)
def foo(q1: qubit, q2: qubit @owned) -> qubit: ...
def foo(q1: qubit, q2: qubit @ owned) -> qubit: ...

@guppy(module)
def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]:
def test(q1: qubit @ owned, q2: qubit @ owned) -> tuple[qubit, qubit]:
q2 = foo(q1, q2)
return q1, q2

Expand All @@ -47,7 +47,7 @@ def test_local(validate):
def foo(q: qubit) -> None: ...

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
f = foo
f(q)
return q
Expand All @@ -63,7 +63,7 @@ def test_nested_calls(validate):
def foo(x: int, q: qubit) -> int: ...

@guppy(module)
def test(q: qubit @owned) -> tuple[int, qubit]:
def test(q: qubit @ owned) -> tuple[int, qubit]:
# This is legal since function arguments and tuples are evaluated left to right
return foo(foo(foo(0, q), q), q), q

Expand All @@ -86,13 +86,13 @@ def foo(q1: qubit, q2: qubit) -> None: ...
def bar(a: MyStruct) -> None: ...

@guppy(module)
def test1(a: MyStruct @owned) -> MyStruct:
def test1(a: MyStruct @ owned) -> MyStruct:
foo(a.q1, a.q2)
bar(a)
return a

@guppy(module)
def test2(a: MyStruct @owned) -> MyStruct:
def test2(a: MyStruct @ owned) -> MyStruct:
bar(a)
foo(a.q1, a.q2)
bar(a)
Expand All @@ -112,7 +112,7 @@ def foo(q: qubit) -> None: ...
def bar(q: qubit) -> bool: ...

@guppy(module)
def test(q1: qubit @owned, q2: qubit @owned, n: int) -> tuple[qubit, qubit]:
def test(q1: qubit @ owned, q2: qubit @ owned, n: int) -> tuple[qubit, qubit]:
i = 0
while i < n:
foo(q1)
Expand Down Expand Up @@ -162,13 +162,15 @@ class C:
def foo(a: A, x: int) -> None: ...

@guppy.declare(module)
def bar(y: float, b: B, c: C @owned) -> C: ...
def bar(y: float, b: B, c: C @ owned) -> C: ...

@guppy.declare(module)
def baz(c: C) -> None: ...

@guppy(module)
def test(a: A @owned, b: B @owned, c1: C @owned, c2: C @owned, x: bool) -> tuple[A, B, C, C]:
def test(
a: A @ owned, b: B @ owned, c1: C @ owned, c2: C @ owned, x: bool
) -> tuple[A, B, C, C]:
c1 = (foo, bar, baz)(a, b.x, c1.x, b, c1, c2)
if x:
c1 = ((foo, bar), baz)(a, b.x, c1.x, b, c1, c2)
Expand All @@ -191,7 +193,7 @@ def foo(q: qubit) -> None:
h(q)

@guppy(module)
def test(q: qubit @owned) -> qubit:
def test(q: qubit @ owned) -> qubit:
foo(q)
foo(q)
return q
Expand All @@ -208,7 +210,7 @@ def test(q: qubit) -> None:
pass

@guppy(module)
def main(q: qubit @owned) -> qubit:
def main(q: qubit @ owned) -> qubit:
test(q)
return q

Expand All @@ -224,7 +226,7 @@ def foo(q: qubit) -> None: ...

@guppy(module)
def test(
b: int, c: qubit, d: float, a: tuple[qubit, qubit], e: qubit @owned
b: int, c: qubit, d: float, a: tuple[qubit, qubit], e: qubit @ owned
) -> tuple[qubit, float]:
foo(c)
return e, b + d
Expand All @@ -241,7 +243,7 @@ class MyStruct:
q: qubit

@guppy.declare(module)
def use(q: qubit @owned) -> None: ...
def use(q: qubit @ owned) -> None: ...

@guppy(module)
def foo(s: MyStruct) -> None:
Expand All @@ -257,7 +259,7 @@ def swap(s: MyStruct, t: MyStruct) -> None:
s.q, t.q = t.q, s.q

@guppy(module)
def main(s: MyStruct @owned, t: MyStruct @owned) -> MyStruct:
def main(s: MyStruct @ owned, t: MyStruct @ owned) -> MyStruct:
foo(s)
swap(s, t)
bar(t)
Expand All @@ -276,10 +278,12 @@ class MyStruct:
q: qubit

@guppy.declare(module)
def use(q: qubit @owned) -> None: ...
def use(q: qubit @ owned) -> None: ...

@guppy(module)
def test(s: MyStruct, b: bool, n: int, q1: qubit @owned, q2: qubit @owned) -> None:
def test(
s: MyStruct, b: bool, n: int, q1: qubit @ owned, q2: qubit @ owned
) -> None:
use(s.q)
if b:
s.q = q1
Expand All @@ -299,7 +303,7 @@ def test(s: MyStruct, b: bool, n: int, q1: qubit @owned, q2: qubit @owned) -> No
return

@guppy(module)
def main(s: MyStruct @owned) -> MyStruct:
def main(s: MyStruct @ owned) -> MyStruct:
test(s, False, 5, qubit(), qubit())
return s

Expand All @@ -325,6 +329,7 @@ def bar(self: "MyStruct", b: bool) -> None:

validate(module.compile())


def test_subtype(validate):
module = GuppyModule("test")
module.load_all(quantum)
Expand All @@ -340,6 +345,7 @@ def main() -> qubit:

validate(module.compile())


def test_shadow_check(validate):
module = GuppyModule("test")

Expand All @@ -354,3 +360,20 @@ def main(i: qubit) -> None:
foo(i)

validate(module.compile())


def test_self_qubit(validate):
module = GuppyModule("test")
module.load(qubit)

@guppy(module)
def test() -> bool:
q0 = qubit()

result = q0.measure_reset()
q0.measure_return()
q0.measure()
qubit().discard()
return result

validate(module.compile())

0 comments on commit 242fa44

Please sign in to comment.