Skip to content

Commit

Permalink
Merge branch 'main' into feat/hugr-json
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q authored Jan 15, 2024
2 parents 8db2a05 + 7fac597 commit fb561eb
Show file tree
Hide file tree
Showing 23 changed files with 356 additions and 33 deletions.
23 changes: 13 additions & 10 deletions guppy/cfg/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.nodes import NestedFunctionDef
from guppy.nodes import NestedFunctionDef, PyExpr

if TYPE_CHECKING:
from guppy.cfg.cfg import BaseCFG
Expand Down Expand Up @@ -83,10 +83,9 @@ def compute_variable_stats(self) -> None:
visitor = VariableVisitor(self)
for s in self.statements:
visitor.visit(s)
self._vars = visitor.stats

if self.branch_pred is not None:
self._vars.update_used(self.branch_pred)
visitor.visit(self.branch_pred)
self._vars = visitor.stats


class VariableVisitor(ast.NodeVisitor):
Expand All @@ -99,24 +98,31 @@ def __init__(self, bb: BB):
self.bb = bb
self.stats = VariableStats()

def visit_Name(self, node: ast.Name) -> None:
self.stats.update_used(node)

def visit_Assign(self, node: ast.Assign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
for t in node.targets:
for name in name_nodes_in_ast(t):
self.stats.assigned[name.id] = node

def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.stats.update_used(node.value)
self.visit(node.value)
self.stats.update_used(node.target) # The target is also used
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.stats.update_used(node.value)
self.visit(node.value)
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_PyExpr(self, node: PyExpr) -> None:
# Don't look into `py(...)` expressions
pass

def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# In order to compute the used external variables in a nested function
# definition, we have to run live variable analysis first
Expand All @@ -139,6 +145,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:

# The name of the function is now assigned
self.stats.assigned[node.name] = node

def generic_visit(self, node: ast.AST) -> None:
self.stats.update_used(node)
20 changes: 18 additions & 2 deletions guppy/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from collections.abc import Iterator
from typing import NamedTuple

from guppy.ast_util import AstVisitor, set_location_from
from guppy.ast_util import AstVisitor, set_location_from, with_loc
from guppy.cfg.bb import BB, BBStatement
from guppy.cfg.cfg import CFG
from guppy.checker.core import Globals
from guppy.error import GuppyError, InternalGuppyError
from guppy.gtypes import NoneType
from guppy.nodes import NestedFunctionDef
from guppy.nodes import NestedFunctionDef, PyExpr

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand Down Expand Up @@ -258,6 +258,22 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
# The final value is stored in the temporary variable
return self._make_var(tmp, node)

def visit_Call(self, node: ast.Call) -> ast.AST:
# Parse compile-time evaluated `py(...)` expression
if isinstance(node.func, ast.Name) and node.func.id == "py":
match node.args:
case []:
raise GuppyError(
"Compile-time `py(...)` expression requires an argument",
node,
)
case [arg]:
pass
case args:
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
return with_loc(node, PyExpr(value=arg))
return self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> ast.AST:
# Short-circuit expressions must be built using the `BranchBuilder`. However, we
# can turn them into regular expressions by assigning True/False to a temporary
Expand Down
53 changes: 50 additions & 3 deletions guppy/checker/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import NamedTuple
from typing import Any, NamedTuple

from guppy.ast_util import AstNode
from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.gtypes import (
BoolType,
FunctionType,
Expand Down Expand Up @@ -52,6 +52,9 @@ class TypeVarDecl:
linear: bool


PyScope = dict[str, Any]


class Globals(NamedTuple):
"""Collection of names that are available on module-level.
Expand All @@ -62,6 +65,7 @@ class Globals(NamedTuple):
values: dict[str, Variable]
types: dict[str, type[GuppyType]]
type_vars: dict[str, TypeVarDecl]
python_scope: PyScope

@staticmethod
def default() -> "Globals":
Expand All @@ -73,7 +77,7 @@ def default() -> "Globals":
NoneType.name: NoneType,
BoolType.name: BoolType,
}
return Globals({}, tys, {})
return Globals({}, tys, {}, {})

def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None:
"""Looks up an instance function with a given name for a type.
Expand All @@ -92,6 +96,7 @@ def __or__(self, other: "Globals") -> "Globals":
self.values | other.values,
self.types | other.types,
self.type_vars | other.type_vars,
self.python_scope | other.python_scope,
)

def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
Expand All @@ -112,6 +117,48 @@ class Context(NamedTuple):
locals: Locals


class DummyEvalDict(PyScope):
"""A custom dict that can be passed to `eval` to give better error messages.
This class is used to implement the `py(...)` expression. If the user tries to
access a Guppy variable in the Python context, we give an informative error message.
"""

ctx: Context
node: ast.expr

@dataclass
class GuppyVarUsedError(BaseException):
"""Error that is raised when the user tries to access a Guppy variable."""

var: str
node: ast.Name | None

def __init__(self, ctx: Context, node: ast.expr):
super().__init__(**ctx.globals.python_scope)
self.ctx = ctx
self.node = node

def _check_item(self, key: str) -> None:
# Catch the user trying to access Guppy variables
if key in self.ctx.locals:
# Find the name node in the AST where the usage occurs
n = next((n for n in name_nodes_in_ast(self.node) if n.id == key), None)
raise self.GuppyVarUsedError(key, n)

def __getitem__(self, key: str) -> Any:
self._check_item(key)
return super().__getitem__(key)

def __delitem__(self, key: str) -> None:
self._check_item(key)
super().__delitem__(key)

def __contains__(self, key: object) -> bool:
if isinstance(key, str) and key in self.ctx.locals:
return True
return super().__contains__(key)


def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
Expand Down
50 changes: 47 additions & 3 deletions guppy/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
"""

import ast
import sys
import traceback
from contextlib import suppress
from typing import Any, NoReturn
from typing import Any, NoReturn, cast

from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type
from guppy.checker.core import CallableVariable, Context, Globals
from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals
from guppy.error import (
GuppyError,
GuppyTypeError,
Expand All @@ -42,7 +44,7 @@
TupleType,
unify,
)
from guppy.nodes import GlobalName, LocalCall, LocalName, TypeApply
from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr, TypeApply

# Mapping from unary AST op to dunder method and display name
unary_table: dict[type[ast.unaryop], tuple[str, str]] = {
Expand Down Expand Up @@ -324,6 +326,43 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
else:
raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func)

def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]:
# The method we used for obtaining the Python variables in scope only works in
# CPython (see `get_py_scope()`).
if sys.implementation.name != "cpython":
raise GuppyError(
"Compile-time `py(...)` expressions are only supported in CPython", node
)

try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(self.ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"`py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e

if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals):
return with_loc(node, ast.Constant(value=python_val)), ty

raise GuppyError(
f"Python expression of type `{type(python_val)}` is not supported by Guppy",
node,
)

def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]:
raise InternalGuppyError(
"BB contains `NamedExpr`. Should have been removed during CFG"
Expand Down Expand Up @@ -614,5 +653,10 @@ def python_value_to_guppy_type(
return globals.types["int"].build(node=node)
case float():
return globals.types["float"].build(node=node)
case tuple(elts):
tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts]
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[GuppyType], tys))
case _:
return None
2 changes: 1 addition & 1 deletion guppy/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def check_nested_func_def(
if not captured:
# If there are no captured vars, we treat the function like a global name
func = DefinedFunction(func_def.name, func_ty, func_def, None)
globals = ctx.globals | Globals({func_def.name: func}, {}, {})
globals = ctx.globals | Globals({func_def.name: func}, {}, {}, {})

else:
# Otherwise, we treat it like a local name
Expand Down
21 changes: 14 additions & 7 deletions guppy/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,17 @@ def python_value_to_hugr(v: Any) -> val.Value | None:
"""
from guppy.prelude._internal import bool_value, float_value, int_value

if isinstance(v, bool):
return bool_value(v)
elif isinstance(v, int):
return int_value(v)
elif isinstance(v, float):
return float_value(v)
return None
match v:
case bool():
return bool_value(v)
case int():
return int_value(v)
case float():
return float_value(v)
case tuple(elts):
vs = [python_value_to_hugr(elt) for elt in elts]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
case _:
return None
2 changes: 1 addition & 1 deletion guppy/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class BaseOp(ABC, BaseModel):

# Parent node index of node the op belongs to, used only at serialization time
parent: NodeID = 0
input_extensions: ExtensionSet = Field(default_factory=list)
input_extensions: ExtensionSet | None = None

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
"""Hook to insert type information from the input and output ports into the
Expand Down
Loading

0 comments on commit fb561eb

Please sign in to comment.