From 23d140d6d41ea11468cf593745d3d56dc3b704d0 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 3 Jan 2024 16:39:27 +0100 Subject: [PATCH] feat: Add py expressions --- guppy/cfg/bb.py | 23 +++++----- guppy/cfg/builder.py | 21 ++++++++- guppy/checker/core.py | 53 +++++++++++++++++++++-- guppy/checker/expr_checker.py | 42 ++++++++++++++++-- guppy/checker/func_checker.py | 2 +- guppy/compiler/expr_compiler.py | 21 ++++++--- guppy/decorator.py | 44 +++++++++++++++++-- guppy/module.py | 25 +++++++---- guppy/nodes.py | 8 ++++ tests/error/py_errors/__init__.py | 0 tests/error/py_errors/guppy_name1.err | 7 +++ tests/error/py_errors/guppy_name1.py | 6 +++ tests/error/py_errors/guppy_name2.err | 7 +++ tests/error/py_errors/guppy_name2.py | 9 ++++ tests/error/py_errors/no_args.err | 7 +++ tests/error/py_errors/no_args.py | 6 +++ tests/error/py_errors/python_err.err | 12 ++++++ tests/error/py_errors/python_err.py | 6 +++ tests/error/test_py_errors.py | 20 +++++++++ tests/integration/test_py.py | 62 +++++++++++++++++++++++++++ tests/integration/util.py | 7 +++ 21 files changed, 349 insertions(+), 39 deletions(-) create mode 100644 tests/error/py_errors/__init__.py create mode 100644 tests/error/py_errors/guppy_name1.err create mode 100644 tests/error/py_errors/guppy_name1.py create mode 100644 tests/error/py_errors/guppy_name2.err create mode 100644 tests/error/py_errors/guppy_name2.py create mode 100644 tests/error/py_errors/no_args.err create mode 100644 tests/error/py_errors/no_args.py create mode 100644 tests/error/py_errors/python_err.err create mode 100644 tests/error/py_errors/python_err.py create mode 100644 tests/error/test_py_errors.py create mode 100644 tests/integration/test_py.py diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 99fd61a2..a91965ca 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -6,7 +6,7 @@ from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.nodes import NestedFunctionDef +from guppy.nodes import NestedFunctionDef, PyExpr if TYPE_CHECKING: from guppy.cfg.cfg import BaseCFG @@ -83,10 +83,9 @@ def compute_variable_stats(self) -> None: visitor = VariableVisitor(self) for s in self.statements: visitor.visit(s) - self._vars = visitor.stats - if self.branch_pred is not None: - self._vars.update_used(self.branch_pred) + visitor.visit(self.branch_pred) + self._vars = visitor.stats class VariableVisitor(ast.NodeVisitor): @@ -99,24 +98,31 @@ def __init__(self, bb: BB): self.bb = bb self.stats = VariableStats() + def visit_Name(self, node: ast.Name) -> None: + self.stats.update_used(node) + def visit_Assign(self, node: ast.Assign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) for t in node.targets: for name in name_nodes_in_ast(t): self.stats.assigned[name.id] = node def visit_AugAssign(self, node: ast.AugAssign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) self.stats.update_used(node.target) # The target is also used for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node def visit_AnnAssign(self, node: ast.AnnAssign) -> None: if node.value: - self.stats.update_used(node.value) + self.visit(node.value) for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_PyExpr(self, node: PyExpr) -> None: + # Don't look into `py(...)` expressions + pass + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function # definition, we have to run live variable analysis first @@ -139,6 +145,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # The name of the function is now assigned self.stats.assigned[node.name] = node - - def generic_visit(self, node: ast.AST) -> None: - self.stats.update_used(node) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index 878437bf..d2cee19d 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,13 +3,13 @@ from collections.abc import Iterator from typing import NamedTuple -from guppy.ast_util import AstVisitor, set_location_from +from guppy.ast_util import AstVisitor, set_location_from, with_loc from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef +from guppy.nodes import NestedFunctionDef, PyExpr # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -258,6 +258,23 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: # The final value is stored in the temporary variable return self._make_var(tmp, node) + def visit_Call(self, node: ast.Call) -> ast.AST: + # Parse compile-time evaluated `py(...)` expression + if isinstance(node.func, ast.Name) and node.func.id == "py": + match node.args: + case []: + raise GuppyError( + "Compile-time evaluated `py(...)` expression requires an " + "argument", + node, + ) + case [arg]: + pass + case args: + arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) + return with_loc(node, PyExpr(value=arg)) + return self.generic_visit(node) + def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we # can turn them into regular expressions by assigning True/False to a temporary diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 25dd256d..37d8b5be 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,9 +1,9 @@ import ast from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import NamedTuple +from typing import Any, NamedTuple -from guppy.ast_util import AstNode +from guppy.ast_util import AstNode, name_nodes_in_ast from guppy.gtypes import ( BoolType, FunctionType, @@ -43,6 +43,9 @@ def synthesize_call( """Synthesizes the return type of a function call.""" +PyScope = dict[str, Any] + + class Globals(NamedTuple): """Collection of names that are available on module-level. @@ -52,6 +55,7 @@ class Globals(NamedTuple): values: dict[str, Variable] types: dict[str, type[GuppyType]] + python_scope: PyScope @staticmethod def default() -> "Globals": @@ -63,7 +67,7 @@ def default() -> "Globals": NoneType.name: NoneType, BoolType.name: BoolType, } - return Globals({}, tys) + return Globals({}, tys, {}) def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None: """Looks up an instance function with a given name for a type. @@ -81,6 +85,7 @@ def __or__(self, other: "Globals") -> "Globals": return Globals( self.values | other.values, self.types | other.types, + self.python_scope | other.python_scope, ) def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 @@ -100,6 +105,48 @@ class Context(NamedTuple): locals: Locals +class DummyEvalDict(dict[str, Any]): + """A custom dict that can be passed to `eval` to give better error messages. + This class is used to implement the `py(...)` expression. If the user tries to + access a Guppy variable in the Python context, we give an informative error message. + """ + + ctx: Context + node: ast.expr + + @dataclass + class GuppyVarUsedError(BaseException): + """Error that is raised when the user tries to access a Guppy variable.""" + + var: str + node: ast.Name | None + + def __init__(self, ctx: Context, node: ast.expr): + super().__init__(**ctx.globals.python_scope) + self.ctx = ctx + self.node = node + + def _check_item(self, key: str) -> None: + # Catch the user trying to access Guppy variables + if key in self.ctx.locals: + # Find the name node in the AST where the usage occurs + n = next((n for n in name_nodes_in_ast(self.node) if n.id == key), None) + raise self.GuppyVarUsedError(key, n) + + def __getitem__(self, key: str) -> Any: + self._check_item(key) + return super().__getitem__(key) + + def __delitem__(self, key: str) -> None: + self._check_item(key) + super().__delitem__(key) + + def __contains__(self, key: object) -> bool: + if isinstance(key, str) and key in self.ctx.locals: + return True + return super().__contains__(key) + + def qualified_name(ty: type[GuppyType] | str, name: str) -> str: """Returns a qualified name for an instance function on a type.""" ty_name = ty if isinstance(ty, str) else ty.name diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 731d8f7c..d9955653 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -21,14 +21,15 @@ """ import ast +import traceback from contextlib import suppress -from typing import Any, NoReturn +from typing import Any, NoReturn, cast from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type -from guppy.checker.core import CallableVariable, Context, Globals +from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError from guppy.gtypes import BoolType, FunctionType, GuppyType, TupleType -from guppy.nodes import GlobalName, LocalCall, LocalName +from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -259,6 +260,36 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: + try: + python_val = eval( # noqa: S307, PGH001 + ast.unparse(node.value), + None, + DummyEvalDict(self.ctx, node.value), + ) + except DummyEvalDict.GuppyVarUsedError as e: + raise GuppyError( + f"Guppy variable `{e.var}` cannot be accessed in a compile-time " + "evaluated `py(...)` expression", + e.node or node, + ) from None + except Exception as e: # noqa: BLE001 + # Remove the top frame pointing to the `eval` call from the stack trace + tb = e.__traceback__.tb_next if e.__traceback__ else None + raise GuppyError( + "Error occurred while evaluating Python expression:\n\n" + + "".join(traceback.format_exception(type(e), e, tb)), + node, + ) from e + + if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals): + return with_loc(node, ast.Constant(value=python_val)), ty + + raise GuppyError( + f"Python expression of type `{type(python_val)}` is not supported by Guppy", + node, + ) + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: raise InternalGuppyError( "BB contains `NamedExpr`. Should have been removed during CFG" @@ -360,5 +391,10 @@ def python_value_to_guppy_type( return globals.types["int"].build(node=node) case float(): return globals.types["float"].build(node=node) + case tuple(elts): + tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] + if any(ty is None for ty in tys): + return None + return TupleType(cast(list[GuppyType], tys)) case _: return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index 20113685..d1604629 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -131,7 +131,7 @@ def check_nested_func_def( if not captured: # If there are no captured vars, we treat the function like a global name func = DefinedFunction(func_def.name, func_ty, func_def, None) - globals = ctx.globals | Globals({func_def.name: func}, {}) + globals = ctx.globals | Globals({func_def.name: func}, {}, {}) else: # Otherwise, we treat it like a local name diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index f74c18d2..b589684b 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -107,10 +107,17 @@ def python_value_to_hugr(v: Any) -> val.Value | None: """ from guppy.prelude._internal import bool_value, float_value, int_value - if isinstance(v, bool): - return bool_value(v) - elif isinstance(v, int): - return int_value(v) - elif isinstance(v, float): - return float_value(v) - return None + match v: + case bool(): + return bool_value(v) + case int(): + return int_value(v) + case float(): + return float_value(v) + case tuple(elts): + vs = [python_value_to_hugr(elt) for elt in elts] + if any(value is None for value in vs): + return None + return val.Tuple(vs=vs) + case _: + return None diff --git a/guppy/decorator.py b/guppy/decorator.py index 32f93ab1..06022169 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -1,9 +1,11 @@ import functools +import inspect from collections.abc import Callable from dataclasses import dataclass from typing import Any from guppy.ast_util import AstNode, has_empty_body +from guppy.checker.core import PyScope from guppy.custom import ( CustomCallChecker, CustomCallCompiler, @@ -46,7 +48,7 @@ def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator: def dec(f: Callable[..., Any]) -> Callable[..., Any]: assert isinstance(arg, GuppyModule) - arg.register_func_def(f) + arg.register_func_def(f, get_python_scope()) @functools.wraps(f) def dummy(*args: Any, **kwargs: Any) -> Any: @@ -59,7 +61,7 @@ def dummy(*args: Any, **kwargs: Any) -> Any: return dec else: module = self._module or GuppyModule("module") - module.register_func_def(arg) + module.register_func_def(arg, get_python_scope()) return module.compile() @pretty_errors @@ -68,7 +70,7 @@ def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorato module._instance_func_buffer = {} def dec(c: type) -> type: - module._register_buffered_instance_funcs(ty) + module._register_buffered_instance_funcs(ty, get_python_scope()) return c return dec @@ -118,7 +120,7 @@ def __str__(self) -> str: NewType.__name__ = name NewType.__qualname__ = _name module.register_type(_name, NewType) - module._register_buffered_instance_funcs(NewType) + module._register_buffered_instance_funcs(NewType, get_python_scope()) setattr(c, "_guppy_type", NewType) return c @@ -190,3 +192,37 @@ def dummy(*args: Any, **kwargs: Any) -> Any: guppy = _Guppy() + + +def get_python_scope() -> PyScope: + """Looks up all available Python variables from the call-site. + + Walks up the call stack until we have left the `guppy` module. + """ + # Note that this approach will yield unintended results if the user doesn't invoke + # the decorator directly. For example: + # + # def my_dec(f): + # some_local = ... + # return guppy(f) + # + # @my_dec + # def guppy_func(x: int) -> int: + # .... + # + # Here, we would reach the scope of `my_dec` and `some_local` would be available + # in the Guppy code. + # TODO: Is there a better way to obtain the variables in scope? Note that we + # could do `inspect.getclosurevars(f)` but it will fail if `f` has recursive + # calls. A custom solution based on `f.__code__.co_freevars` and + # `f.__closure__` would only work for CPython. + frame = inspect.currentframe() + if frame is None: + return {} + while frame.f_back is not None and frame.f_globals["__name__"].startswith("guppy."): + frame = frame.f_back + py_scope = frame.f_globals | frame.f_locals + # Explicitly delete frame to avoid reference cycle. + # See https://docs.python.org/3/library/inspect.html#the-interpreter-stack + del frame + return py_scope diff --git a/guppy/module.py b/guppy/module.py index 3061bafe..56ad303e 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -6,7 +6,7 @@ from typing import Any, Union from guppy.ast_util import AstNode, annotate_location -from guppy.checker.core import Globals, qualified_name +from guppy.checker.core import Globals, PyScope, qualified_name from guppy.checker.func_checker import DefinedFunction, check_global_func_def from guppy.compiler.core import CompiledGlobals from guppy.compiler.func_compiler import CompiledFunctionDef, compile_global_func_def @@ -37,7 +37,7 @@ class GuppyModule: _compiled_globals: CompiledGlobals # Mappings of functions defined in this module - _func_defs: dict[str, ast.FunctionDef] + _func_defs: dict[str, tuple[ast.FunctionDef, PyScope]] _func_decls: dict[str, ast.FunctionDef] _custom_funcs: dict[str, CustomFunction] @@ -48,7 +48,7 @@ class GuppyModule: def __init__(self, name: str, import_builtins: bool = True): self.name = name - self._globals = Globals({}, {}) + self._globals = Globals({}, {}, {}) self._compiled_globals = {} self._imported_globals = Globals.default() self._imported_compiled_globals = {} @@ -88,7 +88,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: self.load(val) def register_func_def( - self, f: PyFunc, instance: type[GuppyType] | None = None + self, f: PyFunc, py_scope: PyScope, instance: type[GuppyType] | None = None ) -> None: """Registers a Python function definition as belonging to this Guppy module.""" self._check_not_yet_compiled() @@ -100,7 +100,7 @@ def register_func_def( qualified_name(instance, func_ast.name) if instance else func_ast.name ) self._check_name_available(name, func_ast) - self._func_defs[name] = func_ast + self._func_defs[name] = func_ast, py_scope.copy() def register_func_decl(self, f: PyFunc) -> None: """Registers a Python function declaration as belonging to this Guppy module.""" @@ -126,7 +126,9 @@ def register_type(self, name: str, ty: type[GuppyType]) -> None: """Registers an existing Guppy type as belonging to this Guppy module.""" self._globals.types[name] = ty - def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: + def _register_buffered_instance_funcs( + self, instance: type[GuppyType], py_scope: PyScope + ) -> None: assert self._instance_func_buffer is not None buffer = self._instance_func_buffer self._instance_func_buffer = None @@ -134,7 +136,7 @@ def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: if isinstance(f, CustomFunction): self.register_custom_func(f, instance) else: - self.register_func_def(f, instance) + self.register_func_def(f, py_scope, instance) @property def compiled(self) -> bool: @@ -151,7 +153,7 @@ def compile(self) -> Hugr | None: func.check_type(self._imported_globals | self._globals) defined_funcs = { x: DefinedFunction.from_ast(f, x, self._imported_globals | self._globals) - for x, f in self._func_defs.items() + for x, (f, _) in self._func_defs.items() } declared_funcs = { x: DeclaredFunction.from_ast(f, x, self._imported_globals | self._globals) @@ -163,7 +165,12 @@ def compile(self) -> Hugr | None: # Type check function definitions checked = { - x: check_global_func_def(f, self._imported_globals | self._globals) + x: check_global_func_def( + f, + self._imported_globals + | self._globals + | Globals({}, {}, self._func_defs[x][1]), + ) for x, f in defined_funcs.items() } diff --git a/guppy/nodes.py b/guppy/nodes.py index 22730db5..a906f5b4 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -50,6 +50,14 @@ class GlobalCall(ast.expr): ) +class PyExpr(ast.expr): + """A compile-time evaluated `py(...)` expression.""" + + value: ast.expr + + _fields = ("value",) + + class NestedFunctionDef(ast.FunctionDef): cfg: "CFG" ty: FunctionType diff --git a/tests/error/py_errors/__init__.py b/tests/error/py_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/py_errors/guppy_name1.err b/tests/error/py_errors/guppy_name1.err new file mode 100644 index 00000000..3a13623d --- /dev/null +++ b/tests/error/py_errors/guppy_name1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(x: int) -> int: +6: return py(x + 1) + ^ +GuppyError: Guppy variable `x` cannot be accessed in a compile-time evaluated `py(...)` expression diff --git a/tests/error/py_errors/guppy_name1.py b/tests/error/py_errors/guppy_name1.py new file mode 100644 index 00000000..d1697d0a --- /dev/null +++ b/tests/error/py_errors/guppy_name1.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(x: int) -> int: + return py(x + 1) diff --git a/tests/error/py_errors/guppy_name2.err b/tests/error/py_errors/guppy_name2.err new file mode 100644 index 00000000..fd6a5da0 --- /dev/null +++ b/tests/error/py_errors/guppy_name2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy +8: def foo(x: int) -> int: +9: return py(x + 1) + ^ +GuppyError: Guppy variable `x` cannot be accessed in a compile-time evaluated `py(...)` expression diff --git a/tests/error/py_errors/guppy_name2.py b/tests/error/py_errors/guppy_name2.py new file mode 100644 index 00000000..44ca9852 --- /dev/null +++ b/tests/error/py_errors/guppy_name2.py @@ -0,0 +1,9 @@ +from guppy.decorator import guppy + + +x = 42 + + +@guppy +def foo(x: int) -> int: + return py(x + 1) diff --git a/tests/error/py_errors/no_args.err b/tests/error/py_errors/no_args.err new file mode 100644 index 00000000..e287869e --- /dev/null +++ b/tests/error/py_errors/no_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py() + ^^^^ +GuppyError: Compile-time evaluated `py(...)` expression requires an argument diff --git a/tests/error/py_errors/no_args.py b/tests/error/py_errors/no_args.py new file mode 100644 index 00000000..26bdda63 --- /dev/null +++ b/tests/error/py_errors/no_args.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo() -> int: + return py() diff --git a/tests/error/py_errors/python_err.err b/tests/error/py_errors/python_err.err new file mode 100644 index 00000000..9d4c1e70 --- /dev/null +++ b/tests/error/py_errors/python_err.err @@ -0,0 +1,12 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py(1 / 0) + ^^^^^^^^^ +GuppyError: Error occurred while evaluating Python expression: + +Traceback (most recent call last): + File "", line 1, in +ZeroDivisionError: division by zero + diff --git a/tests/error/py_errors/python_err.py b/tests/error/py_errors/python_err.py new file mode 100644 index 00000000..803ec227 --- /dev/null +++ b/tests/error/py_errors/python_err.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo() -> int: + return py(1 / 0) diff --git a/tests/error/test_py_errors.py b/tests/error/test_py_errors.py new file mode 100644 index 00000000..0dacbd53 --- /dev/null +++ b/tests/error/test_py_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "py_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_py_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py new file mode 100644 index 00000000..2e896aee --- /dev/null +++ b/tests/integration/test_py.py @@ -0,0 +1,62 @@ +from guppy.decorator import guppy +from tests.integration.util import py + + +def test_basic(validate): + x = 42 + + @guppy + def foo() -> int: + return py(x + 1) + + validate(foo) + + +def test_builtin(validate): + @guppy + def foo() -> int: + return py(len({"a": 1337, "b": None})) + + validate(foo) + + +def test_if(validate): + b = True + + @guppy + def foo() -> int: + if py(b or 1 > 6): + return 0 + return 1 + + validate(foo) + + +def test_redeclare_after(validate): + x = 1 + + @guppy + def foo() -> int: + return py(x) + + x = False + + validate(foo) + + +def test_tuple(validate): + @guppy + def foo() -> int: + x, y = py((1, False)) + return x + + validate(foo) + + +def test_tuple_implicit(validate): + @guppy + def foo() -> int: + x, y = py(1, False) + return x + + validate(foo) diff --git a/tests/integration/util.py b/tests/integration/util.py index a99478e2..89343c73 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -1,3 +1,5 @@ +from typing import Any + import validator @@ -13,3 +15,8 @@ def __matmul__(self, other): # Dummy names to import to avoid errors for `_@functional` pseudo-decorator: functional = Decorator() _ = Decorator() + + +def py(*args: Any) -> Any: + """Dummy function to import to avoid errors for `py(...)` expressions""" + raise NotImplementedError