Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Local implicit modules for @guppy #105

Merged
merged 16 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 106 additions & 23 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import inspect
from collections.abc import Callable, Iterator, Sequence
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar, TypeVar

from guppylang.ast_util import AstNode, has_empty_body
Expand All @@ -12,55 +15,105 @@
DefaultCallCompiler,
OpCompiler,
)
from guppylang.error import GuppyError, pretty_errors
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.gtypes import GuppyType, TypeTransformer
from guppylang.hugr import ops, tys
from guppylang.hugr.hugr import Hugr
from guppylang.module import GuppyModule, PyFunc, parse_py_func

FuncDecorator = Callable[[PyFunc], PyFunc]
FuncDecorator = Callable[[PyFunc], PyFunc | Hugr]
CustomFuncDecorator = Callable[[PyFunc], CustomFunction]
ClassDecorator = Callable[[type], type]


@dataclass(frozen=True)
class CallerIdentifier:
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
"""Identifier for the interpreter frame that called the decorator."""
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved

filename: Path
module: ModuleType | None

@property
def name(self) -> str:
"""Returns a user-friendly name for the caller.

If the called is not a function, uses the file name.
"""
if self.module is not None:
return str(self.module.__name__)
return self.filename.name


class _Guppy:
"""Class for the `@guppy` decorator."""

# The current module
_module: GuppyModule | None
# The currently-alive modules, associated with an element in the call stack.
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
#
# Only contains **uncompiled** modules.
_modules: dict[CallerIdentifier, GuppyModule]

def __init__(self) -> None:
self._module = None

def set_module(self, module: GuppyModule) -> None:
self._module = module
self._modules = {}

@pretty_errors
def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator:
def __call__(self, arg: PyFunc | GuppyModule) -> FuncDecorator:
"""Decorator to annotate Python functions as Guppy code.

Optionally, the `GuppyModule` in which the function should be placed can be
passed to the decorator.
Optionally, the `GuppyModule` in which the function should be placed can
be passed to the decorator.
"""
if isinstance(arg, GuppyModule):

def make_dummy(wraps: PyFunc) -> Callable[..., Any]:
@functools.wraps(wraps)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)

return dummy

if not isinstance(arg, GuppyModule):
# Decorator used without any arguments.
# We default to a module associated with the caller of the decorator.
f = arg

def dec(f: Callable[..., Any]) -> Callable[..., Any]:
assert isinstance(arg, GuppyModule)
arg.register_func_def(f)
caller = self._get_python_caller(f)
if caller not in self._modules:
self._modules[caller] = GuppyModule(caller.name)
module = self._modules[caller]
module.register_func_def(f)
return make_dummy(f)

@functools.wraps(f)
def dummy(*args: Any, **kwargs: Any) -> Any:
raise GuppyError(
"Guppy functions can only be called in a Guppy context"
)
return dec(f)
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved

return dummy
if isinstance(arg, GuppyModule):
# Module passed.
def dec(f: Callable[..., Any]) -> Callable[..., Any]:
arg.register_func_def(f)
return make_dummy(f)

return dec

raise ValueError(f"Invalid arguments to `@guppy` decorator: {arg}")

def _get_python_caller(self, fn: PyFunc | None = None) -> CallerIdentifier:
"""Returns an identifier for the interpreter frame that called the decorator.
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved

:param fn: Optional. The function that was decorated.
"""
if fn is not None:
filename = inspect.getfile(fn)
module = inspect.getmodule(fn)
else:
module = self._module or GuppyModule("module")
module.register_func_def(arg)
return module.compile()
for s in inspect.stack():
if s.filename != __file__:
filename = s.filename
module = inspect.getmodule(s.frame)
break
else:
raise GuppyError("Could not find a caller for the `@guppy` decorator")
return CallerIdentifier(Path(filename), module)

@pretty_errors
def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator:
Expand Down Expand Up @@ -208,5 +261,35 @@ def dummy(*args: Any, **kwargs: Any) -> Any:

return dec

def take_module(self, id: CallerIdentifier | None = None) -> GuppyModule:
"""Returns the local GuppyModule, removing it from the local state."""
orig_id = id
if id is None:
id = self._get_python_caller()
if id not in self._modules:
err = (
f"Module {orig_id.name} not found."
if orig_id
else "No Guppy functions or types defined in this module."
)
raise MissingModuleError(err)
return self._modules.pop(id)

def compile_module(self, id: CallerIdentifier | None = None) -> Hugr | None:
"""Compiles the local module into a Hugr."""
module = self.take_module(id)
if not module:
err = (
f"Module {id.name} not found."
if id
else "No Guppy functions or types defined in this module."
)
raise MissingModuleError(err)
return module.compile()

def registered_modules(self) -> list[CallerIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return list(self._modules.keys())


guppy = _Guppy()
8 changes: 6 additions & 2 deletions guppylang/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class GuppyTypeInferenceError(GuppyError):
"""Special Guppy exception for type inference errors."""


class MissingModuleError(GuppyError):
"""Special Guppy exception for operations that require a guppy module."""


class InternalGuppyError(Exception):
"""Exception for internal problems during compilation."""

Expand Down Expand Up @@ -166,7 +170,7 @@ def pretty_errors(f: FuncT) -> FuncT:
"""Decorator to print custom error banners when a `GuppyError` occurs."""

@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except GuppyError as err:
Expand All @@ -188,4 +192,4 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
sys.exit(1)
return None

return cast(FuncT, wrapped)
return cast(FuncT, pretty_errors_wrapped)
5 changes: 5 additions & 0 deletions guppylang/hugr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Python definitions for the Hierarchical Unified Graph Representation."""

from .hugr import Hugr

__all__ = ["Hugr"]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int]) -> None:
6: [x for x in xs if x < 5 and x != 6]
^^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_short_circuit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> None:
[x for x in xs if x < 5 and x != 6]
2 changes: 1 addition & 1 deletion tests/error/comprehension_errors/illegal_ternary.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int], ys: list[int], b: bool) -> None:
6: [x for x in (xs if b else ys)]
^^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_ternary.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int], ys: list[int], b: bool) -> None:
[x for x in (xs if b else ys)]
2 changes: 1 addition & 1 deletion tests/error/comprehension_errors/illegal_walrus.err
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
4: @compile_guppy
5: def foo(xs: list[int]) -> None:
6: [y := x for x in xs]
^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/error/comprehension_errors/illegal_walrus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> None:
[y := x for x in xs]
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/and_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool, y: int) -> int:
if x and (z := y + 1):
return z
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_expr_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
(y := 1) if x else (z := 2)
return z
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_expr_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
(y := y + 1) if x else (y := True)
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_not_defined.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/else_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
if x:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 3
_@functional
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_new_var.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
for _ in xs:
y = 5
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
for x in xs:
pass
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_target_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[bool]) -> int:
x = 5
for x in xs:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/for_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(xs: list[int]) -> int:
y = 5
for x in xs:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/if_different_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
if x:
y = 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
_@functional
if x:
Expand Down
4 changes: 2 additions & 2 deletions tests/error/errors_on_usage/if_expr_cond_type_change.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from guppylang.decorator import guppy
from tests.util import compile_guppy


@guppy
@compile_guppy
def foo(x: bool) -> int:
y = 4
0 if (y := x) else (y := 6)
Expand Down
Loading