Skip to content

Commit

Permalink
fix: Fix struct definitions in notebooks (#374)
Browse files Browse the repository at this point in the history
Fixes #373.

Trying to do `inspect.getsourclines` for a class defined in a notebook
yields an `OSError: source code not available`. See
* ipython/ipython#11249
* https://bugs.python.org/issue33826

The fix suggested in ipython/ipython#11249
only works for classes with methods. Instead, I'm using a solution
inspired by wandb/weave#1864 that retrieves the
cell sources of the active IPython session and searches them one by one.
  • Loading branch information
mark-koch authored Aug 12, 2024
1 parent 1dfebef commit b009465
Showing 1 changed file with 55 additions and 12 deletions.
67 changes: 55 additions & 12 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any
from typing import Any, cast

from guppylang.ast_util import AstNode, annotate_location
from guppylang.checker.core import Globals
Expand Down Expand Up @@ -223,19 +223,62 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return [constructor_def]


def is_running_ipython() -> bool:
"""Checks if we are currently running in IPython"""
try:
return get_ipython() is not None # type: ignore[name-defined]
except NameError:
return False


def get_ipython_cell_sources() -> list[str]:
"""Returns the source code of all cells in the running IPython session.
See https://github.com/wandb/weave/pull/1864
"""
shell = get_ipython() # type: ignore[name-defined] # noqa: F821
if not hasattr(shell, "user_ns"):
raise AttributeError("Cannot access user namespace")
cells = cast(list[str], shell.user_ns["In"])
# First cell is always empty
return cells[1:]


def parse_py_class(cls: type) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
source_lines, line_offset = inspect.getsourcelines(cls)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
cls_ast = ast.parse(source).body[0]
file = inspect.getsourcefile(cls)
if file is None:
raise GuppyError("Couldn't determine source file for class")
annotate_location(cls_ast, source, file, line_offset)
if not isinstance(cls_ast, ast.ClassDef):
raise GuppyError("Expected a class definition", cls_ast)
return cls_ast
# We cannot use `inspect.getsourcelines` if we're running in IPython. See
# - https://bugs.python.org/issue33826
# - https://github.com/ipython/ipython/issues/11249
# - https://github.com/wandb/weave/pull/1864
if is_running_ipython():
cell_sources = get_ipython_cell_sources()
# Search cells in reverse order to find the most recent version of the class
for i, cell_source in enumerate(reversed(cell_sources)):
try:
cell_ast = ast.parse(cell_source)
except SyntaxError:
continue
# Search body in reverse order to find the most recent version of the class
for node in reversed(cell_ast.body):
if getattr(node, "name", None) == cls.__name__:
cell_name = f"<In [{len(cell_sources) - i}]>"
annotate_location(node, cell_source, cell_name, 1)
if not isinstance(node, ast.ClassDef):
raise GuppyError("Expected a class definition", node)
return node
raise ValueError(f"Couldn't find source for class `{cls.__name__}`")
else:
source_lines, line_offset = inspect.getsourcelines(cls)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
cls_ast = ast.parse(source).body[0]
file = inspect.getsourcefile(cls)
if file is None:
raise GuppyError("Couldn't determine source file for class")
annotate_location(cls_ast, source, file, line_offset)
if not isinstance(cls_ast, ast.ClassDef):
raise GuppyError("Expected a class definition", cls_ast)
return cls_ast


def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None:
Expand Down

0 comments on commit b009465

Please sign in to comment.