Skip to content

Commit

Permalink
feat!: Add implicit importing of modules (#461)
Browse files Browse the repository at this point in the history
Closes #133 and closes #462.

* When using implicit modules, automatically import definitions by
inspecting the `__dict__` of the containing Python module.
* Implicit modules are no longer discarded once they are compiled for
the first time.
* Updated examples to use implicit modules

BREAKING CHANGE: `guppy.take_module` renamed to `guppy.get_module` and
no longer removes the module from the state.
  • Loading branch information
mark-koch authored Sep 11, 2024
1 parent ce0f746 commit 1b73032
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 59 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ Guppy is a quantum programming language that is fully embedded into Python.
It allows you to write high-level hybrid quantum programs with classical control flow and mid-circuit measurements using Pythonic syntax:

```python
from guppylang import guppy, qubit, quantum
from guppylang import guppy
from guppylang.prelude.quantum import cx, h, measure, qubit, x, z

guppy.load_all(quantum)


# Teleports the state in `src` to `tgt`.
@guppy
def teleport(src: qubit, tgt: qubit) -> qubit:
"""Teleports the state in `src` to `tgt`."""
# Create ancilla and entangle it with src and tgt
tmp = qubit()
tmp, tgt = cx(h(tmp), tgt)
Expand All @@ -31,6 +29,8 @@ def teleport(src: qubit, tgt: qubit) -> qubit:
if measure(tmp):
tgt = x(tgt)
return tgt

guppy.compile_module()
```

More examples and tutorials are available [here][examples].
Expand Down
15 changes: 5 additions & 10 deletions examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@
import math
from collections.abc import Callable

import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.angles import angle
from guppylang.prelude.builtins import py, result
from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x

module = GuppyModule("test")
module.load_all(quantum)
module.load(angle)

sqrt_e = math.sqrt(math.e)
sqrt_e_div = math.sqrt((math.e - 1) / math.e)


@guppy(module)
@guppy
def random_walk_phase_estimation(
eigenstate: Callable[[], qubit],
controlled_oracle: Callable[[qubit, qubit, float], tuple[qubit, qubit]],
Expand Down Expand Up @@ -66,7 +61,7 @@ def random_walk_phase_estimation(
return mu


@guppy(module)
@guppy
def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qubit]:
"""A controlled e^itH gate for the example Hamiltonian H = -0.5 * Z"""
# This is just a controlled rz gate
Expand All @@ -77,14 +72,14 @@ def example_controlled_oracle(q1: qubit, q2: qubit, t: float) -> tuple[qubit, qu
return cx(q1, q2)


@guppy(module)
@guppy
def example_eigenstate() -> qubit:
"""The eigenstate of e^itH for the example Hamiltonian H = -0.5 * Z"""
# This is just |1>
return x(qubit())


@guppy(module)
@guppy
def main() -> int:
num_iters = 24 # To avoid underflows
reset_rate = 8
Expand All @@ -102,4 +97,4 @@ def main() -> int:
return 0


hugr = module.compile()
hugr = guppy.compile_module()
15 changes: 5 additions & 10 deletions examples/t_factory.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
import numpy as np

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.angles import angle, pi
from guppylang.prelude.builtins import linst, py
from guppylang.prelude.quantum import (
cz,
discard,
h,
measure,
quantum,
qubit,
rx,
rz,
)

module = GuppyModule("t_factory")
module.load_all(quantum)
module.load(angle, pi)

phi = np.arccos(1 / 3)


@guppy(module)
@guppy
def ry(q: qubit, theta: angle) -> qubit:
q = rx(q, pi / 2)
q = rz(q, theta + pi)
Expand All @@ -31,14 +26,14 @@ def ry(q: qubit, theta: angle) -> qubit:


# Preparation of approximate T state, from https://arxiv.org/abs/2310.12106
@guppy(module)
@guppy
def prepare_approx(q: qubit) -> qubit:
q = ry(q, angle(py(phi)))
return rz(q, pi / 4)


# The inverse of the [[5,3,1]] encoder in figure 3 of https://arxiv.org/abs/2208.01863
@guppy(module)
@guppy
def distill(
target: qubit, q0: qubit, q1: qubit, q2: qubit, q3: qubit
) -> tuple[qubit, bool]:
Expand All @@ -61,7 +56,7 @@ def distill(
return target, success


@guppy(module)
@guppy
def t_state(timeout: int) -> tuple[linst[qubit], bool]:
"""Create a T state using magic state distillation with `timeout` attempts.
Expand Down Expand Up @@ -91,4 +86,4 @@ def t_state(timeout: int) -> tuple[linst[qubit], bool]:
return [], False


hugr = module.compile()
hugr = guppy.compile_module()
61 changes: 35 additions & 26 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import inspect
from collections.abc import Callable
from collections.abc import Callable, KeysView
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType, ModuleType
Expand All @@ -11,8 +11,9 @@
from hugr import tys as ht
from hugr import val as hv

import guppylang
from guppylang.ast_util import annotate_location, has_empty_body
from guppylang.definition.common import DefId
from guppylang.definition.common import DefId, Definition
from guppylang.definition.const import RawConstDef
from guppylang.definition.custom import (
CustomCallChecker,
Expand All @@ -29,7 +30,7 @@
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.module import GuppyModule, PyFunc
from guppylang.module import GuppyModule, PyFunc, find_guppy_module_in_py_module
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

Expand All @@ -52,13 +53,14 @@ class ModuleIdentifier:
#: module, we only take the module path into account.
name: str = field(compare=False)

#: A reference to the python module
module: ModuleType | None = field(compare=False)


class _Guppy:
"""Class for the `@guppy` decorator."""

# The currently-alive GuppyModules, associated with a Python file/module.
#
# Only contains **uncompiled** modules.
# The currently-alive GuppyModules, associated with a Python file/module
_modules: dict[ModuleIdentifier, GuppyModule]

def __init__(self) -> None:
Expand All @@ -81,11 +83,7 @@ def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionD
# Decorator used without any arguments.
# We default to a module associated with the caller of the decorator.
f = arg

caller = self._get_python_caller(f)
if caller not in self._modules:
self._modules[caller] = GuppyModule(caller.name)
module = self._modules[caller]
module = self.get_module()
return module.register_func_def(f)

if isinstance(arg, GuppyModule):
Expand All @@ -110,12 +108,14 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
if s.filename != __file__:
filename = s.filename
module = inspect.getmodule(s.frame)
break
# Skip frames from the `pretty_error` decorator
if module != guppylang.error:
break
else:
raise GuppyError("Could not find a caller for the `@guppy` decorator")
module_path = Path(filename)
return ModuleIdentifier(
module_path, module.__name__ if module else module_path.name
module_path, module.__name__ if module else module_path.name, module
)

@pretty_errors
Expand Down Expand Up @@ -291,23 +291,32 @@ def load(self, m: ModuleType | GuppyModule) -> None:
module = self._modules[caller]
module.load_all(m)

def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule:
"""Returns the local GuppyModule, removing it from the local state."""
orig_id = id
def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule:
"""Returns the local GuppyModule."""
if id is None:
id = self._get_python_caller()
if id not in self._modules:
err = (
f"Module {orig_id.name} not found."
if orig_id
else "No Guppy functions or types defined in this module."
)
raise MissingModuleError(err)
return self._modules.pop(id)
self._modules[id] = GuppyModule(id.name.split(".")[-1])
module = self._modules[id]
# Update implicit imports
if id.module:
defs: dict[str, Definition | ModuleType] = {}
for x, value in id.module.__dict__.items():
if isinstance(value, Definition) and value.id.module != module:
defs[x] = value
elif isinstance(value, ModuleType):
try:
other_module = find_guppy_module_in_py_module(value)
if other_module and other_module != module:
defs[x] = value
except GuppyError:
pass
module.load(**defs)
return module

def compile_module(self, id: ModuleIdentifier | None = None) -> hugr.ext.Package:
"""Compiles the local module into a Hugr."""
module = self.take_module(id)
module = self.get_module(id)
if not module:
err = (
f"Module {id.name} not found."
Expand All @@ -317,9 +326,9 @@ def compile_module(self, id: ModuleIdentifier | None = None) -> hugr.ext.Package
raise MissingModuleError(err)
return module.compile()

def registered_modules(self) -> list[ModuleIdentifier]:
def registered_modules(self) -> KeysView[ModuleIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return list(self._modules.keys())
return self._modules.keys()


guppy = _Guppy()
Expand Down
9 changes: 9 additions & 0 deletions guppylang/module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import sys
from collections.abc import Callable, Mapping
from pathlib import Path
from types import ModuleType
from typing import Any

Expand Down Expand Up @@ -406,6 +407,14 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule:
Raises a user-error if no unique module can be found.
"""
mods = [val for val in module.__dict__.values() if isinstance(val, GuppyModule)]
# Also include implicit modules
from guppylang.decorator import ModuleIdentifier, guppy

if hasattr(module, "__file__") and module.__file__:
module_id = ModuleIdentifier(Path(module.__file__), module.__name__, module)
if module_id in guppy.registered_modules():
mods.append(guppy.get_module(module_id))

if not mods:
msg = f"No Guppy modules found in `{module.__name__}`"
raise GuppyError(msg)
Expand Down
10 changes: 5 additions & 5 deletions quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ allows you to write high-level hybrid quantum programs with classical control
flow and mid-circuit measurements using Pythonic syntax:

```python
from guppylang import guppy, qubit, quantum
from guppylang import guppy
from guppylang.prelude.quantum import cx, h, measure, qubit, x, z

guppy.load_all(quantum)


# Teleports the state in `src` to `tgt`.
@guppy
def teleport(src: qubit, tgt: qubit) -> qubit:
"""Teleports the state in `src` to `tgt`."""
# Create ancilla and entangle it with src and tgt
tmp = qubit()
tmp, tgt = cx(h(tmp), tgt)
Expand All @@ -22,4 +20,6 @@ def teleport(src: qubit, tgt: qubit) -> qubit:
if measure(tmp):
tgt = x(tgt)
return tgt

guppy.compile_module()
```
7 changes: 7 additions & 0 deletions tests/error/misc_errors/implicit_module_error.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
5: def foo() -> int:
6: return 1.0
^^^
GuppyTypeError: Expected return value of type `int`, got `float`
9 changes: 9 additions & 0 deletions tests/error/misc_errors/implicit_module_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from guppylang import guppy


@guppy
def foo() -> int:
return 1.0


guppy.compile_module()
6 changes: 6 additions & 0 deletions tests/integration/modules/implicit_mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from guppylang import guppy


@guppy
def foo(x: int) -> int:
return x + 1
4 changes: 2 additions & 2 deletions tests/integration/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def b() -> None:
def c() -> None:
pass

default_module = guppy.take_module()
default_module = guppy.get_module()

assert not module.contains("a")
assert module.contains("b")
Expand All @@ -34,7 +34,7 @@ def make_module() -> GuppyModule:
def a() -> None:
pass

return guppy.take_module()
return guppy.get_module()

module_a = make_module()
module_b = make_module()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def g_nested() -> None:
string.
"""

default_module = guppy.take_module()
default_module = guppy.get_module()
validate(default_module.compile())


Expand Down
Loading

0 comments on commit 1b73032

Please sign in to comment.