From d5426017320efba8ed8cf2024c37a0b64b0cdce9 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 13 Aug 2024 13:53:42 +0100 Subject: [PATCH] feat: Use cell name instead of file for notebook errors (#382) This makes the compiler output for notebooks deterministic. Closes #381. Builds on top of the solution in #374, generalising from class to function definitions. --- guppylang/definition/function.py | 13 +++++++- guppylang/definition/struct.py | 47 +++++---------------------- guppylang/ipython_inspect.py | 56 ++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 39 deletions(-) create mode 100644 guppylang/ipython_inspect.py diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index befb21ae..0abf32b6 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -20,6 +20,7 @@ from guppylang.definition.value import CallableDef, CompiledCallableDef from guppylang.error import GuppyError from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr, Node, OutPortV +from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.nodes import GlobalCall from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import FunctionType, Type, type_to_row @@ -175,7 +176,17 @@ def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]: source = "".join(source_lines) # Lines already have trailing \n's source = textwrap.dedent(source) func_ast = ast.parse(source).body[0] - file = inspect.getsourcefile(f) + # In Jupyter notebooks, we shouldn't use `inspect.getsourcefile(f)` since it would + # only give us a dummy temporary file + file: str | None + if is_running_ipython(): + file = "" + if isinstance(func_ast, ast.FunctionDef): + defn = find_ipython_def(func_ast.name) + if defn is not None: + file = f"<{defn.cell_name}>" + else: + file = inspect.getsourcefile(f) if file is None: raise GuppyError("Couldn't determine source file for function") annotate_location(func_ast, source, file, line_offset) diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 4c830104..98c5b02e 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property -from typing import Any, cast +from typing import Any from guppylang.ast_util import AstNode, annotate_location from guppylang.checker.core import Globals @@ -24,6 +24,7 @@ from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError, InternalGuppyError from guppylang.hugr_builder.hugr import OutPortV +from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args from guppylang.tys.parsing import type_from_ast @@ -223,27 +224,6 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: return [constructor_def] -def is_running_ipython() -> bool: - """Checks if we are currently running in IPython""" - try: - return get_ipython() is not None # type: ignore[name-defined] - except NameError: - return False - - -def get_ipython_cell_sources() -> list[str]: - """Returns the source code of all cells in the running IPython session. - - See https://github.com/wandb/weave/pull/1864 - """ - shell = get_ipython() # type: ignore[name-defined] # noqa: F821 - if not hasattr(shell, "user_ns"): - raise AttributeError("Cannot access user namespace") - cells = cast(list[str], shell.user_ns["In"]) - # First cell is always empty - return cells[1:] - - def parse_py_class(cls: type) -> ast.ClassDef: """Parses a Python class object into an AST.""" # We cannot use `inspect.getsourcelines` if we're running in IPython. See @@ -251,22 +231,13 @@ def parse_py_class(cls: type) -> ast.ClassDef: # - https://github.com/ipython/ipython/issues/11249 # - https://github.com/wandb/weave/pull/1864 if is_running_ipython(): - cell_sources = get_ipython_cell_sources() - # Search cells in reverse order to find the most recent version of the class - for i, cell_source in enumerate(reversed(cell_sources)): - try: - cell_ast = ast.parse(cell_source) - except SyntaxError: - continue - # Search body in reverse order to find the most recent version of the class - for node in reversed(cell_ast.body): - if getattr(node, "name", None) == cls.__name__: - cell_name = f"" - annotate_location(node, cell_source, cell_name, 1) - if not isinstance(node, ast.ClassDef): - raise GuppyError("Expected a class definition", node) - return node - raise ValueError(f"Couldn't find source for class `{cls.__name__}`") + defn = find_ipython_def(cls.__name__) + if defn is None: + raise ValueError(f"Couldn't find source for class `{cls.__name__}`") + annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1) + if not isinstance(defn.node, ast.ClassDef): + raise GuppyError("Expected a class definition", defn.node) + return defn.node else: source_lines, line_offset = inspect.getsourcelines(cls) source = "".join(source_lines) # Lines already have trailing \n's diff --git a/guppylang/ipython_inspect.py b/guppylang/ipython_inspect.py new file mode 100644 index 00000000..fcfb575c --- /dev/null +++ b/guppylang/ipython_inspect.py @@ -0,0 +1,56 @@ +"""Tools for inspecting source code when running in IPython.""" + +import ast +from typing import NamedTuple, cast + + +def is_running_ipython() -> bool: + """Checks if we are currently running in IPython""" + try: + return get_ipython() is not None # type: ignore[name-defined] + except NameError: + return False + + +def get_ipython_cell_sources() -> list[str]: + """Returns the source code of all cells in the running IPython session. + + See https://github.com/wandb/weave/pull/1864 + """ + shell = get_ipython() # type: ignore[name-defined] # noqa: F821 + if not hasattr(shell, "user_ns"): + raise AttributeError("Cannot access user namespace") + cells = cast(list[str], shell.user_ns["In"]) + # First cell is always empty + return cells[1:] + + +class IPythonDef(NamedTuple): + """AST of a definition in IPython together with the definition cell name.""" + + node: ast.FunctionDef | ast.ClassDef + cell_name: str + cell_source: str + + +def find_ipython_def(name: str) -> IPythonDef | None: + """Tries to find a definition matching a given name in the current IPython session. + + Note that this only finds *top-level* function or class definitions. Nested + definitions are not detected. + + See https://github.com/wandb/weave/pull/1864 + """ + cell_sources = get_ipython_cell_sources() + # Search cells in reverse order to find the most recent version of the definition + for i, cell_source in enumerate(reversed(cell_sources)): + try: + cell_ast = ast.parse(cell_source) + except SyntaxError: + continue + # Search body in reverse order to find the most recent version of the class + for node in reversed(cell_ast.body): + if isinstance(node, ast.FunctionDef | ast.ClassDef) and node.name == name: + cell_name = f"In [{len(cell_sources) - i}]" + return IPythonDef(node, cell_name, cell_source) + return None