diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 65996f29..4c830104 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property -from typing import Any +from typing import Any, cast from guppylang.ast_util import AstNode, annotate_location from guppylang.checker.core import Globals @@ -223,19 +223,62 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: return [constructor_def] +def is_running_ipython() -> bool: + """Checks if we are currently running in IPython""" + try: + return get_ipython() is not None # type: ignore[name-defined] + except NameError: + return False + + +def get_ipython_cell_sources() -> list[str]: + """Returns the source code of all cells in the running IPython session. + + See https://github.com/wandb/weave/pull/1864 + """ + shell = get_ipython() # type: ignore[name-defined] # noqa: F821 + if not hasattr(shell, "user_ns"): + raise AttributeError("Cannot access user namespace") + cells = cast(list[str], shell.user_ns["In"]) + # First cell is always empty + return cells[1:] + + def parse_py_class(cls: type) -> ast.ClassDef: """Parses a Python class object into an AST.""" - source_lines, line_offset = inspect.getsourcelines(cls) - source = "".join(source_lines) # Lines already have trailing \n's - source = textwrap.dedent(source) - cls_ast = ast.parse(source).body[0] - file = inspect.getsourcefile(cls) - if file is None: - raise GuppyError("Couldn't determine source file for class") - annotate_location(cls_ast, source, file, line_offset) - if not isinstance(cls_ast, ast.ClassDef): - raise GuppyError("Expected a class definition", cls_ast) - return cls_ast + # We cannot use `inspect.getsourcelines` if we're running in IPython. See + # - https://bugs.python.org/issue33826 + # - https://github.com/ipython/ipython/issues/11249 + # - https://github.com/wandb/weave/pull/1864 + if is_running_ipython(): + cell_sources = get_ipython_cell_sources() + # Search cells in reverse order to find the most recent version of the class + for i, cell_source in enumerate(reversed(cell_sources)): + try: + cell_ast = ast.parse(cell_source) + except SyntaxError: + continue + # Search body in reverse order to find the most recent version of the class + for node in reversed(cell_ast.body): + if getattr(node, "name", None) == cls.__name__: + cell_name = f"" + annotate_location(node, cell_source, cell_name, 1) + if not isinstance(node, ast.ClassDef): + raise GuppyError("Expected a class definition", node) + return node + raise ValueError(f"Couldn't find source for class `{cls.__name__}`") + else: + source_lines, line_offset = inspect.getsourcelines(cls) + source = "".join(source_lines) # Lines already have trailing \n's + source = textwrap.dedent(source) + cls_ast = ast.parse(source).body[0] + file = inspect.getsourcefile(cls) + if file is None: + raise GuppyError("Couldn't determine source file for class") + annotate_location(cls_ast, source, file, line_offset) + if not isinstance(cls_ast, ast.ClassDef): + raise GuppyError("Expected a class definition", cls_ast) + return cls_ast def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None: