diff --git a/guppylang/std/_internal/compiler/prelude.py b/guppylang/std/_internal/compiler/prelude.py index e5dcfbba..b78d5ea9 100644 --- a/guppylang/std/_internal/compiler/prelude.py +++ b/guppylang/std/_internal/compiler/prelude.py @@ -11,6 +11,10 @@ from hugr import tys as ht from hugr import val as hv +from guppylang.definition.custom import CustomCallCompiler +from guppylang.definition.value import CallReturnWires +from guppylang.error import InternalGuppyError + if TYPE_CHECKING: from hugr.build.dfg import DfBase @@ -123,3 +127,14 @@ def build_unwrap( result is an error. """ return build_unwrap_right(builder, result, error_msg, error_signal) + + +class MemSwapCompiler(CustomCallCompiler): + """Compiler for the `mem_swap` function.""" + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [x, y] = args + return CallReturnWires(regular_returns=[], inout_returns=[y, x]) + + def compile(self, args: list[Wire]) -> list[Wire]: + raise InternalGuppyError("Call compile_with_inouts instead") diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index 250a04e2..4f943e73 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -41,6 +41,7 @@ ListPushCompiler, ListSetitemCompiler, ) +from guppylang.std._internal.compiler.prelude import MemSwapCompiler from guppylang.std._internal.util import ( float_op, int_op, @@ -880,3 +881,8 @@ def zip(x): ... @guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) def __import__(x): ... + + +@guppy.custom(MemSwapCompiler()) +def mem_swap(x: L, y: L) -> None: + """Swaps the values of two variables.""" diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 24cfa5d3..664925bb 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -3,7 +3,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.std.builtins import array, owned +from guppylang.std.builtins import array, owned, mem_swap from tests.util import compile_guppy from guppylang.std.quantum import qubit @@ -258,3 +258,21 @@ def main() -> int: package = module.compile() validate(package) run_int_fn(package, expected=6) + + +def test_mem_swap(validate): + module = GuppyModule("test") + + module.load(qubit) + @guppy(module) + def foo(x: qubit, y: qubit) -> None: + mem_swap(x, y) + + @guppy(module) + def main() -> array[qubit, 2]: + a = array(qubit(), qubit()) + foo(a[0], a[1]) + return a + + package = module.compile() + validate(package)