Skip to content

Commit

Permalink
feat: Local implicit modules for @guppy (#105)
Browse files Browse the repository at this point in the history
Makes `@guppy` define local modules by default. Fixes #101.

- Removes `_Guppy.set_module`
- Adds a `_Guppy.compile` that compiles and returns the local module.
- Similarly, a `_Guppy.take_module` that returns the module without
compiling it (and removes it from the local context).
- Adds a `@guppy(compile=True)` option as a fall-back to the previous
behaviour, where it directly compiles a Hugr.

Look at the first commit. The second is just noise updating the tests to
use the `compile` flag.

I wonder if it'd be better to have a different decorator that compiles
directly to hugrs instead of the `compile` flag.
  • Loading branch information
aborgna-q authored Jan 18, 2024
1 parent a42db7d commit f52a5de
Show file tree
Hide file tree
Showing 115 changed files with 449 additions and 299 deletions.
126 changes: 103 additions & 23 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import inspect
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar, TypeVar

from guppylang.ast_util import AstNode, has_empty_body
Expand All @@ -12,55 +15,102 @@
DefaultCallCompiler,
OpCompiler,
)
from guppylang.error import GuppyError, pretty_errors
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.gtypes import GuppyType, TypeTransformer
from guppylang.hugr import ops, tys
from guppylang.hugr.hugr import Hugr
from guppylang.module import GuppyModule, PyFunc, parse_py_func

FuncDecorator = Callable[[PyFunc], PyFunc]
FuncDecorator = Callable[[PyFunc], PyFunc | Hugr]
CustomFuncDecorator = Callable[[PyFunc], CustomFunction]
ClassDecorator = Callable[[type], type]


@dataclass(frozen=True)
class ModuleIdentifier:
"""Identifier for the Python file/module that called the decorator."""

filename: Path
module: ModuleType | None

@property
def name(self) -> str:
"""Returns a user-friendly name for the caller.
If the called is not a function, uses the file name.
"""
if self.module is not None:
return str(self.module.__name__)
return self.filename.name


class _Guppy:
"""Class for the `@guppy` decorator."""

# The current module
_module: GuppyModule | None
# The currently-alive GuppyModules, associated with a Python file/module.
#
# Only contains **uncompiled** modules.
_modules: dict[ModuleIdentifier, GuppyModule]

def __init__(self) -> None:
self._module = None

def set_module(self, module: GuppyModule) -> None:
self._module = module
self._modules = {}

@pretty_errors
def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator:
def __call__(self, arg: PyFunc | GuppyModule) -> FuncDecorator:
"""Decorator to annotate Python functions as Guppy code.
Optionally, the `GuppyModule` in which the function should be placed can be
passed to the decorator.
Optionally, the `GuppyModule` in which the function should be placed can
be passed to the decorator.
"""
if isinstance(arg, GuppyModule):

def make_dummy(wraps: PyFunc) -> Callable[..., Any]:
@functools.wraps(wraps)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)

return dummy

if not isinstance(arg, GuppyModule):
# Decorator used without any arguments.
# We default to a module associated with the caller of the decorator.
f = arg

caller = self._get_python_caller(f)
if caller not in self._modules:
self._modules[caller] = GuppyModule(caller.name)
module = self._modules[caller]
module.register_func_def(f)
return make_dummy(f)

if isinstance(arg, GuppyModule):
# Module passed.
def dec(f: Callable[..., Any]) -> Callable[..., Any]:
assert isinstance(arg, GuppyModule)
arg.register_func_def(f)
return make_dummy(f)

@functools.wraps(f)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)
return dec

return dummy
raise ValueError(f"Invalid arguments to `@guppy` decorator: {arg}")

return dec
def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
"""Returns an identifier for the Python file/module that called the decorator.
:param fn: Optional. The function that was decorated.
"""
if fn is not None:
filename = inspect.getfile(fn)
module = inspect.getmodule(fn)
else:
module = self._module or GuppyModule("module")
module.register_func_def(arg)
return module.compile()
for s in inspect.stack():
if s.filename != __file__:
filename = s.filename
module = inspect.getmodule(s.frame)
break
else:
raise GuppyError("Could not find a caller for the `@guppy` decorator")
return ModuleIdentifier(Path(filename), module)

@pretty_errors
def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator:
Expand Down Expand Up @@ -208,5 +258,35 @@ def dummy(*args: Any, **kwargs: Any) -> Any:

return dec

def take_module(self, id: ModuleIdentifier | None = None) -> GuppyModule:
"""Returns the local GuppyModule, removing it from the local state."""
orig_id = id
if id is None:
id = self._get_python_caller()
if id not in self._modules:
err = (
f"Module {orig_id.name} not found."
if orig_id
else "No Guppy functions or types defined in this module."
)
raise MissingModuleError(err)
return self._modules.pop(id)

def compile_module(self, id: ModuleIdentifier | None = None) -> Hugr | None:
"""Compiles the local module into a Hugr."""
module = self.take_module(id)
if not module:
err = (
f"Module {id.name} not found."
if id
else "No Guppy functions or types defined in this module."
)
raise MissingModuleError(err)
return module.compile()

def registered_modules(self) -> list[ModuleIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return list(self._modules.keys())


guppy = _Guppy()
8 changes: 6 additions & 2 deletions guppylang/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class GuppyTypeInferenceError(GuppyError):
"""Special Guppy exception for type inference errors."""


class MissingModuleError(GuppyError):
"""Special Guppy exception for operations that require a guppy module."""


class InternalGuppyError(Exception):
"""Exception for internal problems during compilation."""

Expand Down Expand Up @@ -166,7 +170,7 @@ def pretty_errors(f: FuncT) -> FuncT:
"""Decorator to print custom error banners when a `GuppyError` occurs."""

@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except GuppyError as err:
Expand All @@ -188,4 +192,4 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
sys.exit(1)
return None

return cast(FuncT, wrapped)
return cast(FuncT, pretty_errors_wrapped)
2 changes: 1 addition & 1 deletion tests/error/comprehension_errors/illegal_short_circuit.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int]) -> None:
6: [x for x in xs if x < 5 and x != 6]
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_short_circuit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> None:
[x for x in xs if x < 5 and x != 6]
2 changes: 1 addition & 1 deletion tests/error/comprehension_errors/illegal_ternary.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int], ys: list[int], b: bool) -> None:
6: [x for x in (xs if b else ys)]
^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_ternary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int], ys: list[int], b: bool) -> None:
[x for x in (xs if b else ys)]
2 changes: 1 addition & 1 deletion tests/error/comprehension_errors/illegal_walrus.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int]) -> None:
6: [y := x for x in xs]
^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_walrus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> None:
[y := x for x in xs]
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/and_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool, y: int) -> int:
if x and (z := y + 1):
return z
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_expr_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
(y := 1) if x else (z := 2)
return z
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_expr_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
(y := y + 1) if x else (y := True)
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_not_defined_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
if x:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_type_change_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
_@functional
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_new_var.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
for _ in xs:
y = 5
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
for x in xs:
pass
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_target_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[bool]) -> int:
x = 5
for x in xs:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
y = 5
for x in xs:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/if_different_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/if_expr_cond_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 4
0 if (y := x) else (y := 6)
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/if_expr_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
(y := 1) if x else 0
return y
Loading

0 comments on commit f52a5de

Please sign in to comment.