From f0435e7d539ef78fe6bec15188273d834bb999fd Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 18 Jan 2024 14:24:36 +0000 Subject: [PATCH 1/3] feat: Allow lists in py expressions --- guppylang/checker/expr_checker.py | 102 ++++++++++++++----- guppylang/compiler/expr_compiler.py | 24 ++++- guppylang/prelude/_internal.py | 14 ++- tests/error/py_errors/list_different_tys.err | 7 ++ tests/error/py_errors/list_different_tys.py | 6 ++ tests/error/py_errors/list_empty.err | 7 ++ tests/error/py_errors/list_empty.py | 6 ++ tests/integration/test_py.py | 33 ++++++ 8 files changed, 166 insertions(+), 33 deletions(-) create mode 100644 tests/error/py_errors/list_different_tys.err create mode 100644 tests/error/py_errors/list_different_tys.py create mode 100644 tests/error/py_errors/list_empty.err create mode 100644 tests/error/py_errors/list_empty.py diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 4b6a2ea2..4b7433af 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -230,6 +230,21 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: else: raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) + def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]: + python_val = eval_py_expr(node, self.ctx) + if act := python_value_to_guppy_type(python_val, node, self.ctx.globals): + subst = unify(ty, act, {}) + if subst is None: + self._fail(ty, act, node) + act = act.substitute(subst) + subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars} + return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst + + raise GuppyError( + f"Python expression of type `{type(python_val)}` is not supported by Guppy", + node, + ) + def generic_visit(self, node: ast.expr, ty: GuppyType) -> tuple[ast.expr, Subst]: # Try to synthesize and then check if we can unify it with the given type node, synth = self._synthesize(node, allow_free_vars=False) @@ -497,34 +512,7 @@ def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: ) def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: - # The method we used for obtaining the Python variables in scope only works in - # CPython (see `get_py_scope()`). - if sys.implementation.name != "cpython": - raise GuppyError( - "Compile-time `py(...)` expressions are only supported in CPython", node - ) - - try: - python_val = eval( # noqa: S307, PGH001 - ast.unparse(node.value), - None, - DummyEvalDict(self.ctx, node.value), - ) - except DummyEvalDict.GuppyVarUsedError as e: - raise GuppyError( - f"Guppy variable `{e.var}` cannot be accessed in a compile-time " - "`py(...)` expression", - e.node or node, - ) from None - except Exception as e: # noqa: BLE001 - # Remove the top frame pointing to the `eval` call from the stack trace - tb = e.__traceback__.tb_next if e.__traceback__ else None - raise GuppyError( - "Error occurred while evaluating Python expression:\n\n" - + "".join(traceback.format_exception(type(e), e, tb)), - node, - ) from e - + python_val = eval_py_expr(node, self.ctx) if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals): return with_loc(node, ast.Constant(value=python_val)), ty @@ -898,6 +886,38 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None: return node, elt_ty +def eval_py_expr(node: PyExpr, ctx: Context) -> Any: + """Evaluates a `py(...)` expression.""" + # The method we used for obtaining the Python variables in scope only works in + # CPython (see `get_py_scope()`). + if sys.implementation.name != "cpython": + raise GuppyError( + "Compile-time `py(...)` expressions are only supported in CPython", node + ) + + try: + python_val = eval( # noqa: S307, PGH001 + ast.unparse(node.value), + None, + DummyEvalDict(ctx, node.value), + ) + except DummyEvalDict.GuppyVarUsedError as e: + raise GuppyError( + f"Guppy variable `{e.var}` cannot be accessed in a compile-time " + "`py(...)` expression", + e.node or node, + ) from None + except Exception as e: # noqa: BLE001 + # Remove the top frame pointing to the `eval` call from the stack trace + tb = e.__traceback__.tb_next if e.__traceback__ else None + raise GuppyError( + "Error occurred while evaluating Python expression:\n\n" + + "".join(traceback.format_exception(type(e), e, tb)), + node, + ) from e + return python_val + + def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals ) -> GuppyType | None: @@ -917,5 +937,31 @@ def python_value_to_guppy_type( if any(ty is None for ty in tys): return None return TupleType(cast(list[GuppyType], tys)) + case list(): + return _python_list_to_guppy_type(v, node, globals) case _: return None + + +def _python_list_to_guppy_type( + vs: list[Any], node: ast.expr, globals: Globals +) -> ListType | None: + """Turns a Python list into a Guppy type. + + Returns `None` if the list contains different types or types that are not + representable in Guppy. + """ + if len(vs) == 0: + return ListType(ExistentialTypeVar.new("T", False)) + + # All the list elements must have a unifiable types + v, *rest = vs + el_ty = python_value_to_guppy_type(v, node, globals) + if el_ty is None: + return None + for v in rest: + ty = python_value_to_guppy_type(v, node, globals) + if ty is None or (subst := unify(ty, el_ty, {})) is None: + return None + el_ty = el_ty.substitute(subst) + return ListType(el_ty) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 7271a990..c7c6567c 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -16,7 +16,9 @@ BoolType, BoundTypeVar, FunctionType, + GuppyType, Inst, + ListType, NoneType, TupleType, type_to_row, @@ -135,7 +137,7 @@ def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) def visit_Constant(self, node: ast.Constant) -> OutPortV: - if value := python_value_to_hugr(node.value): + if value := python_value_to_hugr(node.value, get_type(node)): const = self.graph.add_constant(value, get_type(node)).out_port(0) return self.graph.add_load_constant(const).out_port(0) raise InternalGuppyError("Unsupported constant expression in compiler") @@ -294,12 +296,17 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: return False -def python_value_to_hugr(v: Any) -> val.Value | None: +def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None: """Turns a Python value into a Hugr value. Returns None if the Python value cannot be represented in Guppy. """ - from guppylang.prelude._internal import bool_value, float_value, int_value + from guppylang.prelude._internal import ( + bool_value, + float_value, + int_value, + list_value, + ) match v: case bool(): @@ -309,9 +316,18 @@ def python_value_to_hugr(v: Any) -> val.Value | None: case float(): return float_value(v) case tuple(elts): - vs = [python_value_to_hugr(elt) for elt in elts] + assert isinstance(exp_ty, TupleType) + vs = [ + python_value_to_hugr(elt, ty) + for elt, ty in zip(elts, exp_ty.element_types) + ] if any(value is None for value in vs): return None return val.Tuple(vs=vs) + case list(elts): + assert isinstance(exp_ty, ListType) + return list_value( + [python_value_to_hugr(elt, exp_ty.element_type) for elt in elts] + ) case _: return None diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 4baf1966..3da1fef3 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -1,5 +1,5 @@ import ast -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel @@ -52,6 +52,13 @@ class ConstF64(BaseModel): value: float +class ListValue(BaseModel): + """Hugr representation of floats in the arithmetic extension.""" + + c: Literal["ListValue"] = "ListValue" + value: list[Any] + + def bool_value(b: bool) -> val.Value: """Returns the Hugr representation of a boolean value.""" return val.Sum(tag=int(b), value=val.Tuple(vs=[])) @@ -67,6 +74,11 @@ def float_value(f: float) -> val.Value: return val.ExtensionVal(c=(ConstF64(value=f),)) +def list_value(v: list[val.Value]) -> val.Value: + """Returns the Hugr representation of a list value.""" + return val.ExtensionVal(c=(ListValue(value=v),)) + + def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType: """Utility method to create Hugr logic ops.""" return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) diff --git a/tests/error/py_errors/list_different_tys.err b/tests/error/py_errors/list_different_tys.err new file mode 100644 index 00000000..dcb972d5 --- /dev/null +++ b/tests/error/py_errors/list_different_tys.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py([1, 1.0]) + ^^^^^^^^^^^^ +GuppyError: Python expression of type `` is not supported by Guppy diff --git a/tests/error/py_errors/list_different_tys.py b/tests/error/py_errors/list_different_tys.py new file mode 100644 index 00000000..80f468c9 --- /dev/null +++ b/tests/error/py_errors/list_different_tys.py @@ -0,0 +1,6 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> int: + return py([1, 1.0]) diff --git a/tests/error/py_errors/list_empty.err b/tests/error/py_errors/list_empty.err new file mode 100644 index 00000000..748c4047 --- /dev/null +++ b/tests/error/py_errors/list_empty.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> None: +6: xs = py([]) + ^^^^^^ +GuppyTypeError: Cannot infer type variable in expression of type `list[?T]` diff --git a/tests/error/py_errors/list_empty.py b/tests/error/py_errors/list_empty.py new file mode 100644 index 00000000..09889f52 --- /dev/null +++ b/tests/error/py_errors/list_empty.py @@ -0,0 +1,6 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> None: + xs = py([]) diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index 85b0ec1f..5f56d44d 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -60,3 +60,36 @@ def foo() -> int: return x validate(foo) + + +def test_list_basic(validate): + @guppy + def foo() -> list[int]: + xs = py([1, 2, 3]) + return xs + + validate(foo) + + +def test_list_empty(validate): + @guppy + def foo() -> list[int]: + return py([]) + + validate(foo) + + +def test_list_empty_nested(validate): + @guppy + def foo() -> None: + xs: list[tuple[int, list[bool]]] = py([(42, [])]) + + validate(foo) + + +def test_list_empty_multiple(validate): + @guppy + def foo() -> None: + xs: tuple[list[int], list[bool]] = py([], []) + + validate(foo) From 5fa46f9e20b50b255f4175f7c37824c663dd935f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 30 Jan 2024 09:31:45 +0000 Subject: [PATCH 2/3] Fix test decorator --- tests/error/py_errors/list_different_tys.err | 2 +- tests/error/py_errors/list_different_tys.py | 4 ++-- tests/error/py_errors/list_empty.err | 2 +- tests/error/py_errors/list_empty.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/error/py_errors/list_different_tys.err b/tests/error/py_errors/list_different_tys.err index dcb972d5..e972e47f 100644 --- a/tests/error/py_errors/list_different_tys.err +++ b/tests/error/py_errors/list_different_tys.err @@ -1,6 +1,6 @@ Guppy compilation failed. Error in file $FILE:6 -4: @guppy +4: @compile_guppy 5: def foo() -> int: 6: return py([1, 1.0]) ^^^^^^^^^^^^ diff --git a/tests/error/py_errors/list_different_tys.py b/tests/error/py_errors/list_different_tys.py index 80f468c9..377b114b 100644 --- a/tests/error/py_errors/list_different_tys.py +++ b/tests/error/py_errors/list_different_tys.py @@ -1,6 +1,6 @@ -from guppylang.decorator import guppy +from tests.util import compile_guppy -@guppy +@compile_guppy def foo() -> int: return py([1, 1.0]) diff --git a/tests/error/py_errors/list_empty.err b/tests/error/py_errors/list_empty.err index 748c4047..1663dcb1 100644 --- a/tests/error/py_errors/list_empty.err +++ b/tests/error/py_errors/list_empty.err @@ -1,6 +1,6 @@ Guppy compilation failed. Error in file $FILE:6 -4: @guppy +4: @compile_guppy 5: def foo() -> None: 6: xs = py([]) ^^^^^^ diff --git a/tests/error/py_errors/list_empty.py b/tests/error/py_errors/list_empty.py index 09889f52..3324daf3 100644 --- a/tests/error/py_errors/list_empty.py +++ b/tests/error/py_errors/list_empty.py @@ -1,6 +1,6 @@ -from guppylang.decorator import guppy +from tests.util import compile_guppy -@guppy +@compile_guppy def foo() -> None: xs = py([]) From 829bf20152ef350c9e8a8b3efa7e5f4dda4fe935 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 30 Jan 2024 09:42:37 +0000 Subject: [PATCH 3/3] Improve error message --- guppylang/checker/expr_checker.py | 4 +++- tests/error/py_errors/list_different_tys.err | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 63f547a8..36570037 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -984,7 +984,9 @@ def _python_list_to_guppy_type( return None for v in rest: ty = python_value_to_guppy_type(v, node, globals) - if ty is None or (subst := unify(ty, el_ty, {})) is None: + if ty is None: return None + if (subst := unify(ty, el_ty, {})) is None: + raise GuppyError("Python list contains elements with different types", node) el_ty = el_ty.substitute(subst) return ListType(el_ty) diff --git a/tests/error/py_errors/list_different_tys.err b/tests/error/py_errors/list_different_tys.err index e972e47f..f39f07f0 100644 --- a/tests/error/py_errors/list_different_tys.err +++ b/tests/error/py_errors/list_different_tys.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:6 5: def foo() -> int: 6: return py([1, 1.0]) ^^^^^^^^^^^^ -GuppyError: Python expression of type `` is not supported by Guppy +GuppyError: Python list contains elements with different types