Skip to content

Commit

Permalink
feat: Add py expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jan 3, 2024
1 parent 89654cd commit 23d140d
Show file tree
Hide file tree
Showing 21 changed files with 349 additions and 39 deletions.
23 changes: 13 additions & 10 deletions guppy/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.nodes import NestedFunctionDef
from guppy.nodes import NestedFunctionDef, PyExpr

if TYPE_CHECKING:
from guppy.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -83,10 +83,9 @@ def compute_variable_stats(self) -> None:
visitor = VariableVisitor(self)
for s in self.statements:
visitor.visit(s)
self._vars = visitor.stats

if self.branch_pred is not None:
self._vars.update_used(self.branch_pred)
visitor.visit(self.branch_pred)
self._vars = visitor.stats


class VariableVisitor(ast.NodeVisitor):
Expand All @@ -99,24 +98,31 @@ def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def visit_Name(self, node: ast.Name) -> None:
self.stats.update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
for t in node.targets:
for name in name_nodes_in_ast(t):
self.stats.assigned[name.id] = node

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.stats.update_used(node.value)
self.visit(node.value)
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_PyExpr(self, node: PyExpr) -> None:
# Don't look into `py(...)` expressions
pass

def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# In order to compute the used external variables in a nested function
# definition, we have to run live variable analysis first
Expand All @@ -139,6 +145,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:

# The name of the function is now assigned
self.stats.assigned[node.name] = node

def generic_visit(self, node: ast.AST) -> None:
self.stats.update_used(node)
21 changes: 19 additions & 2 deletions guppy/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from collections.abc import Iterator
from typing import NamedTuple

from guppy.ast_util import AstVisitor, set_location_from
from guppy.ast_util import AstVisitor, set_location_from, with_loc
from guppy.cfg.bb import BB, BBStatement
from guppy.cfg.cfg import CFG
from guppy.checker.core import Globals
from guppy.error import GuppyError, InternalGuppyError
from guppy.gtypes import NoneType
from guppy.nodes import NestedFunctionDef
from guppy.nodes import NestedFunctionDef, PyExpr

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -258,6 +258,23 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
# The final value is stored in the temporary variable
return self._make_var(tmp, node)

def visit_Call(self, node: ast.Call) -> ast.AST:
# Parse compile-time evaluated `py(...)` expression
if isinstance(node.func, ast.Name) and node.func.id == "py":
match node.args:
case []:
raise GuppyError(
"Compile-time evaluated `py(...)` expression requires an "
"argument",
node,
)
case [arg]:
pass
case args:
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
return with_loc(node, PyExpr(value=arg))
return self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> ast.AST:
# Short-circuit expressions must be built using the `BranchBuilder`. However, we
# can turn them into regular expressions by assigning True/False to a temporary
Expand Down
53 changes: 50 additions & 3 deletions guppy/checker/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import NamedTuple
from typing import Any, NamedTuple

from guppy.ast_util import AstNode
from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.gtypes import (
BoolType,
FunctionType,
Expand Down Expand Up @@ -43,6 +43,9 @@ def synthesize_call(
"""Synthesizes the return type of a function call."""


PyScope = dict[str, Any]


class Globals(NamedTuple):
"""Collection of names that are available on module-level.
Expand All @@ -52,6 +55,7 @@ class Globals(NamedTuple):

values: dict[str, Variable]
types: dict[str, type[GuppyType]]
python_scope: PyScope

@staticmethod
def default() -> "Globals":
Expand All @@ -63,7 +67,7 @@ def default() -> "Globals":
NoneType.name: NoneType,
BoolType.name: BoolType,
}
return Globals({}, tys)
return Globals({}, tys, {})

def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None:
"""Looks up an instance function with a given name for a type.
Expand All @@ -81,6 +85,7 @@ def __or__(self, other: "Globals") -> "Globals":
return Globals(
self.values | other.values,
self.types | other.types,
self.python_scope | other.python_scope,
)

def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
Expand All @@ -100,6 +105,48 @@ class Context(NamedTuple):
locals: Locals


class DummyEvalDict(dict[str, Any]):
"""A custom dict that can be passed to `eval` to give better error messages.
This class is used to implement the `py(...)` expression. If the user tries to
access a Guppy variable in the Python context, we give an informative error message.
"""

ctx: Context
node: ast.expr

@dataclass
class GuppyVarUsedError(BaseException):
"""Error that is raised when the user tries to access a Guppy variable."""

var: str
node: ast.Name | None

def __init__(self, ctx: Context, node: ast.expr):
super().__init__(**ctx.globals.python_scope)
self.ctx = ctx
self.node = node

def _check_item(self, key: str) -> None:
# Catch the user trying to access Guppy variables
if key in self.ctx.locals:
# Find the name node in the AST where the usage occurs
n = next((n for n in name_nodes_in_ast(self.node) if n.id == key), None)
raise self.GuppyVarUsedError(key, n)

def __getitem__(self, key: str) -> Any:
self._check_item(key)
return super().__getitem__(key)

def __delitem__(self, key: str) -> None:
self._check_item(key)
super().__delitem__(key)

def __contains__(self, key: object) -> bool:
if isinstance(key, str) and key in self.ctx.locals:
return True
return super().__contains__(key)


def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
Expand Down
42 changes: 39 additions & 3 deletions guppy/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
"""

import ast
import traceback
from contextlib import suppress
from typing import Any, NoReturn
from typing import Any, NoReturn, cast

from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type
from guppy.checker.core import CallableVariable, Context, Globals
from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals
from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError
from guppy.gtypes import BoolType, FunctionType, GuppyType, TupleType
from guppy.nodes import GlobalName, LocalCall, LocalName
from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr

# Mapping from unary AST op to dunder method and display name
unary_table: dict[type[ast.unaryop], tuple[str, str]] = {
Expand Down Expand Up @@ -259,6 +260,36 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
else:
raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func)

def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]:
try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(self.ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"evaluated `py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e

if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals):
return with_loc(node, ast.Constant(value=python_val)), ty

raise GuppyError(
f"Python expression of type `{type(python_val)}` is not supported by Guppy",
node,
)

def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]:
raise InternalGuppyError(
"BB contains `NamedExpr`. Should have been removed during CFG"
Expand Down Expand Up @@ -360,5 +391,10 @@ def python_value_to_guppy_type(
return globals.types["int"].build(node=node)
case float():
return globals.types["float"].build(node=node)
case tuple(elts):
tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts]
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[GuppyType], tys))
case _:
return None
2 changes: 1 addition & 1 deletion guppy/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def check_nested_func_def(
if not captured:
# If there are no captured vars, we treat the function like a global name
func = DefinedFunction(func_def.name, func_ty, func_def, None)
globals = ctx.globals | Globals({func_def.name: func}, {})
globals = ctx.globals | Globals({func_def.name: func}, {}, {})

else:
# Otherwise, we treat it like a local name
Expand Down
21 changes: 14 additions & 7 deletions guppy/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,17 @@ def python_value_to_hugr(v: Any) -> val.Value | None:
"""
from guppy.prelude._internal import bool_value, float_value, int_value

if isinstance(v, bool):
return bool_value(v)
elif isinstance(v, int):
return int_value(v)
elif isinstance(v, float):
return float_value(v)
return None
match v:
case bool():
return bool_value(v)
case int():
return int_value(v)
case float():
return float_value(v)
case tuple(elts):
vs = [python_value_to_hugr(elt) for elt in elts]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
case _:
return None
44 changes: 40 additions & 4 deletions guppy/decorator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from guppy.ast_util import AstNode, has_empty_body
from guppy.checker.core import PyScope
from guppy.custom import (
CustomCallChecker,
CustomCallCompiler,
Expand Down Expand Up @@ -46,7 +48,7 @@ def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator:

def dec(f: Callable[..., Any]) -> Callable[..., Any]:
assert isinstance(arg, GuppyModule)
arg.register_func_def(f)
arg.register_func_def(f, get_python_scope())

@functools.wraps(f)
def dummy(*args: Any, **kwargs: Any) -> Any:
Expand All @@ -59,7 +61,7 @@ def dummy(*args: Any, **kwargs: Any) -> Any:
return dec
else:
module = self._module or GuppyModule("module")
module.register_func_def(arg)
module.register_func_def(arg, get_python_scope())
return module.compile()

@pretty_errors
Expand All @@ -68,7 +70,7 @@ def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorato
module._instance_func_buffer = {}

def dec(c: type) -> type:
module._register_buffered_instance_funcs(ty)
module._register_buffered_instance_funcs(ty, get_python_scope())
return c

return dec
Expand Down Expand Up @@ -118,7 +120,7 @@ def __str__(self) -> str:
NewType.__name__ = name
NewType.__qualname__ = _name
module.register_type(_name, NewType)
module._register_buffered_instance_funcs(NewType)
module._register_buffered_instance_funcs(NewType, get_python_scope())
setattr(c, "_guppy_type", NewType)
return c

Expand Down Expand Up @@ -190,3 +192,37 @@ def dummy(*args: Any, **kwargs: Any) -> Any:


guppy = _Guppy()


def get_python_scope() -> PyScope:
"""Looks up all available Python variables from the call-site.
Walks up the call stack until we have left the `guppy` module.
"""
# Note that this approach will yield unintended results if the user doesn't invoke
# the decorator directly. For example:
#
# def my_dec(f):
# some_local = ...
# return guppy(f)
#
# @my_dec
# def guppy_func(x: int) -> int:
# ....
#
# Here, we would reach the scope of `my_dec` and `some_local` would be available
# in the Guppy code.
# TODO: Is there a better way to obtain the variables in scope? Note that we
# could do `inspect.getclosurevars(f)` but it will fail if `f` has recursive
# calls. A custom solution based on `f.__code__.co_freevars` and
# `f.__closure__` would only work for CPython.
frame = inspect.currentframe()
if frame is None:
return {}
while frame.f_back is not None and frame.f_globals["__name__"].startswith("guppy."):
frame = frame.f_back
py_scope = frame.f_globals | frame.f_locals
# Explicitly delete frame to avoid reference cycle.
# See https://docs.python.org/3/library/inspect.html#the-interpreter-stack
del frame
return py_scope
Loading

0 comments on commit 23d140d

Please sign in to comment.