Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Local implicit modules for @guppy #105

Merged
merged 16 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 107 additions & 23 deletions guppy/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import functools
import inspect
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar, TypeVar

from guppy.ast_util import AstNode, has_empty_body
Expand All @@ -18,49 +20,112 @@
from guppy.hugr.hugr import Hugr
from guppy.module import GuppyModule, PyFunc, parse_py_func

FuncDecorator = Callable[[PyFunc], PyFunc]
FuncDecorator = Callable[[PyFunc], PyFunc | Hugr]
CustomFuncDecorator = Callable[[PyFunc], CustomFunction]
ClassDecorator = Callable[[type], type]


@dataclass
class CallerIdentifier:
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
"""Identifier for the interpreter frame that called the decorator."""

filename: Path
function: str

@property
def name(self) -> str:
"""Returns a user-friendly name for the caller.

If the called is not a function, uses the file name.
"""
if self.function == "<module>":
return self.filename.name
return self.function

def __hash__(self) -> int:
return hash((self.filename, self.function))


class _Guppy:
"""Class for the `@guppy` decorator."""

# The current module
_module: GuppyModule | None
# The currently-alive modules, associated with an element in the call stack.
_modules: dict[CallerIdentifier, GuppyModule]

def __init__(self) -> None:
self._module = None

def set_module(self, module: GuppyModule) -> None:
self._module = module
self._modules = {}

@pretty_errors
def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator:
def __call__(
self, arg: PyFunc | GuppyModule | None = None, *, compile: bool = False
) -> Hugr | None | FuncDecorator:
"""Decorator to annotate Python functions as Guppy code.

Optionally, the `GuppyModule` in which the function should be placed can be
passed to the decorator.
Optionally, the `GuppyModule` in which the function should be placed can
be passed to the decorator.

If `compile` is set to `True` and no `GuppyModule` is passed, the
function is compiled immediately as an standalone module and the Hugr is
returned.
"""
if arg is not None and not isinstance(arg, GuppyModule):
# Decorator used without any arguments.
f = arg
decorator: FuncDecorator = self.__call__(None) # type: ignore[assignment]
return decorator(f)

def make_dummy(wraps: PyFunc) -> Callable[..., Any]:
@functools.wraps(wraps)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)

return dummy

if arg is None and compile:
# No module passed, and compile option is set.
def dec(f: Callable[..., Any]) -> Callable[..., Any] | Hugr:
module = GuppyModule("module")
module.register_func_def(f)
compiled = module.compile()
assert compiled is not None
return compiled

return dec

if arg is None and not compile:
# No module specified, and `compile` option is not set.
# We use a module associate with the caller of the decorator.
def dec(f: Callable[..., Any]) -> Callable[..., Any] | Hugr:
caller = self._get_python_caller()
if caller not in self._modules:
self._modules[caller] = GuppyModule(caller.name)
module = self._modules[caller]
module.register_func_def(f)
return make_dummy(f)

return dec

if isinstance(arg, GuppyModule):
# Module passed. Ignore `compile` option.

def dec(f: Callable[..., Any]) -> Callable[..., Any]:
assert isinstance(arg, GuppyModule)
def dec(f: Callable[..., Any]) -> Callable[..., Any] | Hugr:
arg.register_func_def(f)
return make_dummy(f)

@functools.wraps(f)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)
return dec

return dummy
raise ValueError(f"Invalid arguments to `@guppy` decorator: {arg}")

return dec
else:
module = self._module or GuppyModule("module")
module.register_func_def(arg)
return module.compile()
def _get_python_caller(self) -> CallerIdentifier:
"""Returns an identifier for the interpreter frame that called the decorator."""
for s in inspect.stack():
# Note the hacky check for the pretty errors wrapper,
# since @pretty_errors wraps the __call__ method.
if s.filename != __file__ and s.function != "pretty_errors_wrapped":
return CallerIdentifier(Path(s.filename), s.function)
raise GuppyError("Could not find caller of `@guppy` decorator")
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

@pretty_errors
def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator:
Expand Down Expand Up @@ -204,5 +269,24 @@ def dummy(*args: Any, **kwargs: Any) -> Any:

return dec

def take_module(self, id: CallerIdentifier | None = None) -> GuppyModule | None:
"""Returns the local GuppyModule, removing it from the local state."""
if id is None:
id = self._get_python_caller()
if id not in self._modules:
return None
module = self._modules[id]
del self._modules[id]
return module
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved

def compile(self, id: CallerIdentifier | None = None) -> Hugr | None:
"""Compiles the local module into a Hugr."""
module = self.take_module(id)
return module.compile() if module else None
mark-koch marked this conversation as resolved.
Show resolved Hide resolved

def registered_modules(self) -> list[CallerIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return list(self._modules.keys())


guppy = _Guppy()
4 changes: 2 additions & 2 deletions guppy/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def pretty_errors(f: FuncT) -> FuncT:
"""Decorator to print custom error banners when a `GuppyError` occurs."""

@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except GuppyError as err:
Expand All @@ -188,4 +188,4 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
sys.exit(1)
return None

return cast(FuncT, wrapped)
return cast(FuncT, pretty_errors_wrapped)
36 changes: 18 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/and_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool, y: int) -> int:
if x and (z := y + 1):
return z
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_expr_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
(y := 1) if x else (z := 2)
return z
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_expr_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 3
(y := y + 1) if x else (y := True)
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/else_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 3
if x:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 3
_@functional
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_different_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_expr_cond_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 4
0 if (y := x) else (y := 6)
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_expr_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
(y := 1) if x else 0
return y
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_expr_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool, a: int) -> int:
y = 3
(y := False) if x or a > 5 else 0
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_expr_type_conflict.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @guppy(compile=True)
5: def foo(x: bool) -> None:
6: y = True if x else 42
^^^^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_expr_type_conflict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> None:
y = True if x else 42
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_not_defined_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 3
if x:
Expand Down
2 changes: 1 addition & 1 deletion tests/error/errors_on_usage/if_type_change_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppy.decorator import guppy


@guppy
@guppy(compile=True)
def foo(x: bool) -> int:
y = 3
_@functional
Expand Down
Loading