Skip to content

Commit

Permalink
feat!: Add qualified imports and make them the default (#443)
Browse files Browse the repository at this point in the history
Closes #426.

BREAKING CHANGE: `GuppyModule.load` no longer loads the content of
modules but instead just brings the name of the module into scope. Use
`GuppyModule.load_all` to get the old behaviour.
  • Loading branch information
mark-koch authored Sep 3, 2024
1 parent a255c02 commit 553ec51
Show file tree
Hide file tree
Showing 114 changed files with 500 additions and 265 deletions.
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@ It allows you to write high-level hybrid quantum programs with classical control
```python
from guppylang import guppy, qubit, quantum

guppy.load(quantum)
guppy.load_all(quantum)


# Teleports the state in `src` to `tgt`.
@guppy
def teleport(src: qubit, tgt: qubit) -> qubit:
# Create ancilla and entangle it with src and tgt
tmp = qubit()
tmp, tgt = cx(h(tmp), tgt)
src, tmp = cx(src, tmp)

# Apply classical corrections
if measure(h(src)):
tgt = z(tgt)
if measure(tmp):
tgt = x(tgt)
return tgt
# Create ancilla and entangle it with src and tgt
tmp = qubit()
tmp, tgt = cx(h(tmp), tgt)
src, tmp = cx(src, tmp)

# Apply classical corrections
if measure(h(src)):
tgt = z(tgt)
if measure(tmp):
tgt = x(tgt)
return tgt
```

More examples and tutorials are available [here][examples].
Expand Down
12 changes: 6 additions & 6 deletions examples/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@
"metadata": {},
"outputs": [],
"source": [
"module.load(guppylang.prelude.quantum)\n",
"from guppylang.prelude.quantum import qubit, h, cx, measure\n",
"\n",
"from guppylang.prelude.quantum import qubit, h, rz, measure\n",
"module.load(qubit, h, cx, measure)\n",
"\n",
"@guppy(module)\n",
"def bell() -> bool:\n",
Expand Down Expand Up @@ -305,7 +305,7 @@
],
"source": [
"bad_module = GuppyModule(name=\"bad_module\")\n",
"bad_module.load(guppylang.prelude.quantum)\n",
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> tuple[qubit, qubit]:\n",
Expand Down Expand Up @@ -348,7 +348,7 @@
],
"source": [
"bad_module = GuppyModule(name=\"bad_module\")\n",
"bad_module.load(guppylang.prelude.quantum)\n",
"bad_module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy(bad_module)\n",
"def bad(q: qubit) -> qubit:\n",
Expand Down Expand Up @@ -412,7 +412,7 @@
"outputs": [],
"source": [
"module = GuppyModule(\"structs\")\n",
"module.load(guppylang.prelude.quantum)\n",
"module.load_all(guppylang.prelude.quantum)\n",
"\n",
"@guppy.struct(module)\n",
"class QubitPair:\n",
Expand Down Expand Up @@ -557,7 +557,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion examples/random_walk_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from guppylang.prelude.quantum import cx, discard, h, measure, qubit, rz, x

module = GuppyModule("test")
module.load(quantum)
module.load_all(quantum)

sqrt_e = math.sqrt(math.e)
sqrt_e_div = math.sqrt((math.e - 1) / math.e)
Expand Down
2 changes: 1 addition & 1 deletion examples/t_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

module = GuppyModule("t_factory")
module.load(quantum)
module.load_all(quantum)

phi = np.arccos(1 / 3)
pi = np.pi
Expand Down
50 changes: 37 additions & 13 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
Locals,
Variable,
)
from guppylang.definition.common import Definition
from guppylang.definition.module import ModuleDef
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef, ValueDef
from guppylang.error import (
Expand Down Expand Up @@ -345,26 +347,40 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.expr, Type]:
var = self.ctx.locals[x]
return with_loc(node, PlaceNode(place=var)), var.ty
elif x in self.ctx.globals:
# Look-up what kind of definition it is
match self.ctx.globals[x]:
case ValueDef() as defn:
return with_loc(node, GlobalName(id=x, def_id=defn.id)), defn.ty
# For types, we return their `__new__` constructor
case TypeDef() as defn if constr := self.ctx.globals.get_instance_func(
defn, "__new__"
):
return with_loc(node, GlobalName(id=x, def_id=constr.id)), constr.ty
case defn:
raise GuppyError(
f"Expected a value, got {defn.description} `{x}`", node
)
defn = self.ctx.globals[x]
return self._check_global(defn, x, node)
raise InternalGuppyError(
f"Variable `{x}` is not defined in `TypeSynthesiser`. This should have "
"been caught by program analysis!"
)

def _check_global(
self, defn: Definition, name: str, node: ast.expr
) -> tuple[ast.expr, Type]:
"""Checks a global definition in an expression position."""
match defn:
case ValueDef() as defn:
return with_loc(node, GlobalName(id=name, def_id=defn.id)), defn.ty
# For types, we return their `__new__` constructor
case TypeDef() as defn if constr := self.ctx.globals.get_instance_func(
defn, "__new__"
):
return with_loc(node, GlobalName(id=name, def_id=constr.id)), constr.ty
case defn:
raise GuppyError(
f"Expected a value, got {defn.description} `{name}`", node
)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# A `value.attr` attribute access
if module_def := self._is_module_def(node.value):
if node.attr not in module_def.globals:
raise GuppyError(
f"Module `{module_def.name}` has no member `{node.attr}`", node
)
defn = module_def.globals[node.attr]
qual_name = f"{module_def.name}.{defn.name}"
return self._check_global(defn, qual_name, node)
node.value, ty = self.synthesize(node.value)
if isinstance(ty, StructType) and node.attr in ty.field_dict:
field = ty.field_dict[node.attr]
Expand Down Expand Up @@ -398,6 +414,14 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
node,
)

def _is_module_def(self, node: ast.expr) -> ModuleDef | None:
"""Checks whether an AST node corresponds to a defined module."""
if isinstance(node, ast.Name) and node.id in self.ctx.globals:
defn = self.ctx.globals[node.id]
if isinstance(defn, ModuleDef):
return defn
return None

def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
elems = [self.synthesize(elem) for elem in node.elts]

Expand Down
2 changes: 1 addition & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def load(self, m: ModuleType | GuppyModule) -> None:
if caller not in self._modules:
self._modules[caller] = GuppyModule(caller.name)
module = self._modules[caller]
module.load(m)
module.load_all(m)

def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule:
"""Returns the local GuppyModule, removing it from the local state."""
Expand Down
22 changes: 22 additions & 0 deletions guppylang/definition/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from guppylang.definition.common import CompiledDef

if TYPE_CHECKING:
from guppylang.checker.core import Globals


@dataclass(frozen=True)
class ModuleDef(CompiledDef):
"""A module definition.
Note that this definition is separate from the `GuppyModule` class and only serves
as a pointer to be stored in the globals.
In the future we could consider unifying this with `GuppyModule`.
"""

globals: "Globals"

description: str = field(default="module", init=False)
68 changes: 47 additions & 21 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from guppylang.definition.declaration import RawFunctionDecl
from guppylang.definition.function import RawFunctionDef
from guppylang.definition.module import ModuleDef
from guppylang.definition.parameter import ParamDef
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(self, name: str, import_builtins: bool = True):
if import_builtins:
import guppylang.prelude.builtins as builtins

self.load(builtins)
self.load_all(builtins)

def load(
self,
Expand Down Expand Up @@ -113,28 +114,17 @@ def load(
names[alias or imp.name] = imp.id
modules.add(module)
elif isinstance(imp, GuppyModule):
# TODO: In the future this should be a qualified module import. For now
# we just import all contained definitions that don't start with `_`
imp.check()
imports.extend(
(defn.name, defn)
for defn in imp._globals.defs.values()
if not defn.name.startswith("_")
)
def_id = DefId.fresh(imp)
name = alias or imp.name
defn = ModuleDef(def_id, name, None, imp._globals)
defs[def_id] = defn
names[name] = def_id
defs |= imp._checked_defs
modules.add(imp)
elif isinstance(imp, ModuleType):
mods = [
val for val in imp.__dict__.values() if isinstance(val, GuppyModule)
]
if not mods:
msg = f"No Guppy modules found in `{imp.__name__}`"
raise GuppyError(msg)
if len(mods) > 1:
msg = (
f"Python module `{imp.__name__}` contains multiple Guppy "
"modules. Cannot decide which one to import."
)
raise GuppyError(msg)
imports.append((alias, mods[0]))
mod = find_guppy_module_in_py_module(imp)
imports.append((alias, mod))
else:
msg = f"Only Guppy definitions or modules can be imported. Got `{imp}`"
raise GuppyError(msg)
Expand All @@ -160,6 +150,23 @@ def load(
for module in modules:
self._imported_checked_defs |= module._imported_checked_defs

def load_all(self, mod: "GuppyModule | ModuleType") -> None:
"""Imports all public members of a module."""
if isinstance(mod, GuppyModule):
mod.check()
self.load(
*(
defn
for defn in mod._globals.defs.values()
if not defn.name.startswith("_")
)
)
elif isinstance(mod, ModuleType):
self.load_all(find_guppy_module_in_py_module(mod))
else:
msg = f"Only Guppy definitions or modules can be imported. Got `{mod}`"
raise GuppyError(msg)

def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
"""Registers a definition with this module.
Expand Down Expand Up @@ -335,3 +342,22 @@ def get_py_scope(f: PyFunc) -> PyScope:
nonlocals[var] = value

return nonlocals | f.__globals__.copy()


def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule:
"""Helper function to search the `__dict__` of a Python module for an instance of
`GuppyModule`.
Raises a user-error if no unique module can be found.
"""
mods = [val for val in module.__dict__.values() if isinstance(val, GuppyModule)]
if not mods:
msg = f"No Guppy modules found in `{module.__name__}`"
raise GuppyError(msg)
if len(mods) > 1:
msg = (
f"Python module `{module.__name__}` contains multiple Guppy modules. "
"Cannot decide which one to import."
)
raise GuppyError(msg)
return mods[0]
Loading

0 comments on commit 553ec51

Please sign in to comment.