diff --git a/README.md b/README.md index 35274be8..491cab0c 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,12 @@ Guppy is a quantum programming language that is fully embedded into Python. It allows you to write high-level hybrid quantum programs with classical control flow and mid-circuit measurements using Pythonic syntax: ```python -from guppylang import guppy, qubit, quantum +from guppylang import guppy +from guppylang.prelude.quantum import cx, h, measure, qubit, x, z -guppy.load_all(quantum) - - -# Teleports the state in `src` to `tgt`. @guppy def teleport(src: qubit, tgt: qubit) -> qubit: + """Teleports the state in `src` to `tgt`.""" # Create ancilla and entangle it with src and tgt tmp = qubit() tmp, tgt = cx(h(tmp), tgt) @@ -31,6 +29,8 @@ def teleport(src: qubit, tgt: qubit) -> qubit: if measure(tmp): tgt = x(tgt) return tgt + +guppy.compile_module() ``` More examples and tutorials are available [here][examples]. diff --git a/examples/random_walk_qpe.py b/examples/random_walk_qpe.py index 8e8c80e3..984d76ac 100644 --- a/examples/random_walk_qpe.py +++ b/examples/random_walk_qpe.py @@ -8,22 +8,17 @@ import math from collections.abc import Callable -import guppylang.prelude.quantum as quantum from guppylang.decorator import guppy -from guppylang.module import GuppyModule from guppylang.prelude.angles import angle from guppylang.prelude.builtins import py, result from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x -module = GuppyModule("test") -module.load_all(quantum) -module.load(angle) sqrt_e = math.sqrt(math.e) sqrt_e_div = math.sqrt((math.e - 1) / math.e) -@guppy(module) +@guppy def random_walk_phase_estimation( eigenstate: Callable[[], qubit], controlled_oracle: Callable[[qubit, qubit, float], tuple[qubit, qubit]], @@ -66,7 +61,7 @@ def random_walk_phase_estimation( return mu -@guppy(module) +@guppy def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qubit]: """A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z""" # This is just a controlled rz gate @@ -77,14 +72,14 @@ def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qu return cx(q1, q2) -@guppy(module) +@guppy def example_eigenstate() -> qubit: """The eigenstate of e^itH for the example Hamiltonian H = -0.5 * Z""" # This is just |1> return x(qubit()) -@guppy(module) +@guppy def main() -> int: num_iters = 24 # To avoid underflows reset_rate = 8 @@ -102,4 +97,4 @@ def main() -> int: return 0 -hugr = module.compile() +hugr = guppy.compile_module() diff --git a/examples/t_factory.py b/examples/t_factory.py index 88bf7b6b..edab580b 100644 --- a/examples/t_factory.py +++ b/examples/t_factory.py @@ -1,7 +1,6 @@ import numpy as np from guppylang.decorator import guppy -from guppylang.module import GuppyModule from guppylang.prelude.angles import angle, pi from guppylang.prelude.builtins import linst, py from guppylang.prelude.quantum import ( @@ -9,20 +8,16 @@ discard, h, measure, - quantum, qubit, rx, rz, ) -module = GuppyModule("t_factory") -module.load_all(quantum) -module.load(angle, pi) phi = np.arccos(1 / 3) -@guppy(module) +@guppy def ry(q: qubit, theta: angle) -> qubit: q = rx(q, pi / 2) q = rz(q, theta + pi) @@ -31,14 +26,14 @@ def ry(q: qubit, theta: angle) -> qubit: # Preparation of approximate T state, from https://arxiv.org/abs/2310.12106 -@guppy(module) +@guppy def prepare_approx(q: qubit) -> qubit: q = ry(q, angle(py(phi))) return rz(q, pi / 4) # The inverse of the [[5,3,1]] encoder in figure 3 of https://arxiv.org/abs/2208.01863 -@guppy(module) +@guppy def distill( target: qubit, q0: qubit, q1: qubit, q2: qubit, q3: qubit ) -> tuple[qubit, bool]: @@ -61,7 +56,7 @@ def distill( return target, success -@guppy(module) +@guppy def t_state(timeout: int) -> tuple[linst[qubit], bool]: """Create a T state using magic state distillation with `timeout` attempts. @@ -91,4 +86,4 @@ def t_state(timeout: int) -> tuple[linst[qubit], bool]: return [], False -hugr = module.compile() +hugr = guppy.compile_module() diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 87442de4..78aa57d6 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -1,6 +1,6 @@ import ast import inspect -from collections.abc import Callable +from collections.abc import Callable, KeysView from dataclasses import dataclass, field from pathlib import Path from types import FrameType, ModuleType @@ -11,8 +11,9 @@ from hugr import tys as ht from hugr import val as hv +import guppylang from guppylang.ast_util import annotate_location, has_empty_body -from guppylang.definition.common import DefId +from guppylang.definition.common import DefId, Definition from guppylang.definition.const import RawConstDef from guppylang.definition.custom import ( CustomCallChecker, @@ -29,7 +30,7 @@ from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, MissingModuleError, pretty_errors -from guppylang.module import GuppyModule, PyFunc +from guppylang.module import GuppyModule, PyFunc, find_guppy_module_in_py_module from guppylang.tys.subst import Inst from guppylang.tys.ty import NumericType @@ -52,13 +53,14 @@ class ModuleIdentifier: #: module, we only take the module path into account. name: str = field(compare=False) + #: A reference to the python module + module: ModuleType | None = field(compare=False) + class _Guppy: """Class for the `@guppy` decorator.""" - # The currently-alive GuppyModules, associated with a Python file/module. - # - # Only contains **uncompiled** modules. + # The currently-alive GuppyModules, associated with a Python file/module _modules: dict[ModuleIdentifier, GuppyModule] def __init__(self) -> None: @@ -81,11 +83,7 @@ def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionD # Decorator used without any arguments. # We default to a module associated with the caller of the decorator. f = arg - - caller = self._get_python_caller(f) - if caller not in self._modules: - self._modules[caller] = GuppyModule(caller.name) - module = self._modules[caller] + module = self.get_module() return module.register_func_def(f) if isinstance(arg, GuppyModule): @@ -110,12 +108,14 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: if s.filename != __file__: filename = s.filename module = inspect.getmodule(s.frame) - break + # Skip frames from the `pretty_error` decorator + if module != guppylang.error: + break else: raise GuppyError("Could not find a caller for the `@guppy` decorator") module_path = Path(filename) return ModuleIdentifier( - module_path, module.__name__ if module else module_path.name + module_path, module.__name__ if module else module_path.name, module ) @pretty_errors @@ -291,23 +291,32 @@ def load(self, m: ModuleType | GuppyModule) -> None: module = self._modules[caller] module.load_all(m) - def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: - """Returns the local GuppyModule, removing it from the local state.""" - orig_id = id + def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: + """Returns the local GuppyModule.""" if id is None: id = self._get_python_caller() if id not in self._modules: - err = ( - f"Module {orig_id.name} not found." - if orig_id - else "No Guppy functions or types defined in this module." - ) - raise MissingModuleError(err) - return self._modules.pop(id) + self._modules[id] = GuppyModule(id.name.split(".")[-1]) + module = self._modules[id] + # Update implicit imports + if id.module: + defs: dict[str, Definition | ModuleType] = {} + for x, value in id.module.__dict__.items(): + if isinstance(value, Definition) and value.id.module != module: + defs[x] = value + elif isinstance(value, ModuleType): + try: + other_module = find_guppy_module_in_py_module(value) + if other_module and other_module != module: + defs[x] = value + except GuppyError: + pass + module.load(**defs) + return module def compile_module(self, id: ModuleIdentifier | None = None) -> hugr.ext.Package: """Compiles the local module into a Hugr.""" - module = self.take_module(id) + module = self.get_module(id) if not module: err = ( f"Module {id.name} not found." @@ -317,9 +326,9 @@ def compile_module(self, id: ModuleIdentifier | None = None) -> hugr.ext.Package raise MissingModuleError(err) return module.compile() - def registered_modules(self) -> list[ModuleIdentifier]: + def registered_modules(self) -> KeysView[ModuleIdentifier]: """Returns a list of all currently registered modules for local contexts.""" - return list(self._modules.keys()) + return self._modules.keys() guppy = _Guppy() diff --git a/guppylang/module.py b/guppylang/module.py index 028a688b..d7d67fec 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -1,6 +1,7 @@ import inspect import sys from collections.abc import Callable, Mapping +from pathlib import Path from types import ModuleType from typing import Any @@ -406,6 +407,14 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule: Raises a user-error if no unique module can be found. """ mods = [val for val in module.__dict__.values() if isinstance(val, GuppyModule)] + # Also include implicit modules + from guppylang.decorator import ModuleIdentifier, guppy + + if hasattr(module, "__file__") and module.__file__: + module_id = ModuleIdentifier(Path(module.__file__), module.__name__, module) + if module_id in guppy.registered_modules(): + mods.append(guppy.get_module(module_id)) + if not mods: msg = f"No Guppy modules found in `{module.__name__}`" raise GuppyError(msg) diff --git a/quickstart.md b/quickstart.md index 40e398ab..3cdca383 100644 --- a/quickstart.md +++ b/quickstart.md @@ -3,14 +3,12 @@ allows you to write high-level hybrid quantum programs with classical control flow and mid-circuit measurements using Pythonic syntax: ```python -from guppylang import guppy, qubit, quantum +from guppylang import guppy +from guppylang.prelude.quantum import cx, h, measure, qubit, x, z -guppy.load_all(quantum) - - -# Teleports the state in `src` to `tgt`. @guppy def teleport(src: qubit, tgt: qubit) -> qubit: + """Teleports the state in `src` to `tgt`.""" # Create ancilla and entangle it with src and tgt tmp = qubit() tmp, tgt = cx(h(tmp), tgt) @@ -22,4 +20,6 @@ def teleport(src: qubit, tgt: qubit) -> qubit: if measure(tmp): tgt = x(tgt) return tgt + +guppy.compile_module() ``` diff --git a/tests/error/misc_errors/implicit_module_error.err b/tests/error/misc_errors/implicit_module_error.err new file mode 100644 index 00000000..c2afc3f8 --- /dev/null +++ b/tests/error/misc_errors/implicit_module_error.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return 1.0 + ^^^ +GuppyTypeError: Expected return value of type `int`, got `float` diff --git a/tests/error/misc_errors/implicit_module_error.py b/tests/error/misc_errors/implicit_module_error.py new file mode 100644 index 00000000..1821b448 --- /dev/null +++ b/tests/error/misc_errors/implicit_module_error.py @@ -0,0 +1,9 @@ +from guppylang import guppy + + +@guppy +def foo() -> int: + return 1.0 + + +guppy.compile_module() diff --git a/tests/integration/modules/implicit_mod.py b/tests/integration/modules/implicit_mod.py new file mode 100644 index 00000000..232f8a3b --- /dev/null +++ b/tests/integration/modules/implicit_mod.py @@ -0,0 +1,6 @@ +from guppylang import guppy + + +@guppy +def foo(x: int) -> int: + return x + 1 diff --git a/tests/integration/test_decorator.py b/tests/integration/test_decorator.py index 09d9a073..80dc48b0 100644 --- a/tests/integration/test_decorator.py +++ b/tests/integration/test_decorator.py @@ -17,7 +17,7 @@ def b() -> None: def c() -> None: pass - default_module = guppy.take_module() + default_module = guppy.get_module() assert not module.contains("a") assert module.contains("b") @@ -34,7 +34,7 @@ def make_module() -> GuppyModule: def a() -> None: pass - return guppy.take_module() + return guppy.get_module() module_a = make_module() module_b = make_module() diff --git a/tests/integration/test_docstring.py b/tests/integration/test_docstring.py index 80227f30..c9312177 100644 --- a/tests/integration/test_docstring.py +++ b/tests/integration/test_docstring.py @@ -32,7 +32,7 @@ def g_nested() -> None: string. """ - default_module = guppy.take_module() + default_module = guppy.get_module() validate(default_module.compile()) diff --git a/tests/integration/test_imports.py b/tests/integration/test_imports.py index 38020603..49443740 100644 --- a/tests/integration/test_imports.py +++ b/tests/integration/test_imports.py @@ -28,6 +28,19 @@ def test(x: MyType) -> MyType: validate(module.compile()) +def test_import_implicit(validate): + from tests.integration.modules.implicit_mod import foo + + module = GuppyModule("test") + module.load(foo) + + @guppy(module) + def test(x: int) -> int: + return foo(x) + + validate(module.compile()) + + def test_func_alias(validate): from tests.integration.modules.mod_a import f as g @@ -151,3 +164,16 @@ def test(x: mod_a.MyType, y: mod_b.MyType) -> tuple[mod_a.MyType, mod_b.MyType]: return -x, +y validate(module.compile()) + + +def test_qualified_implicit(validate): + import tests.integration.modules.implicit_mod as implicit_mod + + module = GuppyModule("test") + module.load(implicit_mod) + + @guppy(module) + def test(x: int) -> int: + return implicit_mod.foo(x) + + validate(module.compile())