Skip to content

Commit

Permalink
fix: Stop exiting interpreter on error (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Feb 1, 2024
1 parent 8221385 commit 728e449
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 27 deletions.
79 changes: 56 additions & 23 deletions guppylang/error.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import ast
import functools
import os
import sys
import textwrap
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, TypeVar, cast

from guppylang.ast_util import AstNode, get_file, get_line_offset, get_source
from guppylang.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType
from guppylang.hugr.hugr import Node, OutPortV

# Whether the interpreter should exit when a Guppy error occurs
EXIT_ON_ERROR: bool = True


@dataclass(frozen=True)
class SourceLoc:
Expand Down Expand Up @@ -136,6 +136,18 @@ def unsolved_vars(self) -> set[ExistentialTypeVar]:
return set()


ExceptHook = Callable[[type[BaseException], BaseException, TracebackType | None], Any]


@contextmanager
def exception_hook(hook: ExceptHook) -> Iterator[None]:
"""Sets a custom `excepthook` for the scope of a 'with' block."""
old_hook = sys.excepthook
sys.excepthook = hook
yield
sys.excepthook = old_hook


def format_source_location(
loc: ast.AST,
num_lines: int = 3,
Expand Down Expand Up @@ -169,27 +181,48 @@ def format_source_location(
def pretty_errors(f: FuncT) -> FuncT:
"""Decorator to print custom error banners when a `GuppyError` occurs."""

def hook(
excty: type[BaseException], err: BaseException, traceback: TracebackType | None
) -> None:
"""Custom `excepthook` that intercepts `GuppyExceptions` for pretty printing."""
# Fall back to default hook if it's not a GuppyException or we're missing an
# error location
if not isinstance(err, GuppyError) or err.location is None:
sys.__excepthook__(excty, err, traceback)
return

loc = err.location
file, line_offset = get_file(loc), get_line_offset(loc)
assert file is not None
assert line_offset is not None
line = line_offset + loc.lineno - 1
sys.stderr.write(
f"Guppy compilation failed. Error in file {file}:{line}\n\n"
f"{format_source_location(loc)}\n"
f"{err.__class__.__name__}: {err.get_msg()}\n",
)

@functools.wraps(f)
def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except GuppyError as err:
# Reraise if we're missing a location
if not err.location:
with exception_hook(hook):
try:
return f(*args, **kwargs)
except GuppyError as err:
# For normal usage, this `try` block is not necessary since the
# excepthook is automatically invoked when the exception (which is being
# reraised below) is not handled. However, when running tests, we have
# to manually invoke the hook to print the error message, since the
# tests always have to capture exceptions.
if _pytest_running():
hook(type(err), err, err.__traceback__)
raise
loc = err.location
file, line_offset = get_file(loc), get_line_offset(loc)
assert file is not None
assert line_offset is not None
line = line_offset + loc.lineno - 1
print( # noqa: T201
f"Guppy compilation failed. Error in file {file}:{line}\n\n"
f"{format_source_location(loc)}\n"
f"{err.__class__.__name__}: {err.get_msg()}",
file=sys.stderr,
)
if EXIT_ON_ERROR:
sys.exit(1)
return None

return cast(FuncT, pretty_errors_wrapped)


def _pytest_running() -> bool:
"""Checks if we are currently running pytest.
See https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
"""
return "PYTEST_CURRENT_TEST" in os.environ
6 changes: 2 additions & 4 deletions tests/error/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import pathlib
import pytest

from typing import Any
from collections.abc import Callable

from guppylang.error import GuppyError
from guppylang.hugr import tys
from guppylang.hugr.tys import TypeBound
from guppylang.module import GuppyModule
Expand All @@ -17,7 +15,7 @@ def run_error_test(file, capsys):
spec = importlib.util.spec_from_file_location("test_module", file)
py_module = importlib.util.module_from_spec(spec)

with pytest.raises(SystemExit):
with pytest.raises(GuppyError):
spec.loader.exec_module(py_module)

err = capsys.readouterr().err
Expand Down

0 comments on commit 728e449

Please sign in to comment.