Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve exception handling #251

Merged
merged 7 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions capsula/_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def __init__(


class CapsuleItem(ABC):
@property
@abstractmethod
def abort_on_error(self) -> bool:
raise NotImplementedError

@abstractmethod
def encapsulate(self) -> Any:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions capsula/_context/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class ContextBase(CapsuleItem):
_subclass_registry: Final[dict[str, type[ContextBase]]] = {}
abort_on_error: bool = False

def __init_subclass__(cls, **kwargs: Any) -> None:
if cls.__name__ in cls._subclass_registry:
Expand Down
53 changes: 31 additions & 22 deletions capsula/_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import logging
import queue
import threading
from collections import deque
Expand Down Expand Up @@ -28,6 +29,8 @@
_P = ParamSpec("_P")
_T = TypeVar("_T")

logger = logging.getLogger(__name__)


class FuncInfo(BaseModel):
func: Callable[..., Any]
Expand Down Expand Up @@ -236,7 +239,7 @@ def __exit__(
) -> None:
self._get_run_stack().get(block=False)

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: # noqa: C901
func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs)
if self._run_dir_generator is None:
msg = "run_dir_generator must be set before calling the function."
Expand Down Expand Up @@ -266,26 +269,32 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:

in_run_enc.add_context(FunctionCallContext(self._func, args, kwargs))

# TODO: `result` will not be defined if `self.func` raises an exception and it is caught by a watcher.
with self, in_run_enc, in_run_enc.watch():
if self._pass_pre_run_capsule:
result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self._func(*args, **kwargs)

in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
reporter.report(in_run_capsule)

params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
reporter.report(post_run_capsule)
try:
with self, in_run_enc, in_run_enc.watch():
if self._pass_pre_run_capsule:
result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self._func(*args, **kwargs)
finally:
in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(in_run_capsule)
except Exception:
logger.exception(f"Failed to report in-run capsule with reporter {reporter}.")

params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(post_run_capsule)
except Exception:
logger.exception(f"Failed to report post-run capsule with reporter {reporter}.")

return result
5 changes: 3 additions & 2 deletions capsula/_watcher/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import queue
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING, Any, Dict, Final, Generic, OrderedDict, TypeVar

Expand All @@ -12,8 +12,9 @@
from types import TracebackType


class WatcherBase(CapsuleItem):
class WatcherBase(CapsuleItem, ABC):
_subclass_registry: Final[dict[str, type[WatcherBase]]] = {}
abort_on_error: bool = False

def __init_subclass__(cls, **kwargs: Any) -> None:
if cls.__name__ in cls._subclass_registry:
Expand Down
7 changes: 2 additions & 5 deletions capsula/_watcher/_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ def __init__(
name: str = "exception",
*,
base: type[BaseException] = Exception,
reraise: bool = False,
) -> None:
self.name = name
self.base = base
self.reraise = reraise
self.exception: BaseException | None = None

def encapsulate(self) -> ExceptionInfo:
Expand All @@ -36,10 +34,9 @@ def watch(self) -> Iterator[None]:
try:
yield
except self.base as e:
logger.debug(f"UncaughtExceptionWatcher: {self.name} caught exception: {e}")
logger.debug(f"UncaughtExceptionWatcher: {self.name} observed exception: {e}")
self.exception = e
if self.reraise:
raise
raise

def default_key(self) -> tuple[str, str]:
return ("exception", self.name)
10 changes: 6 additions & 4 deletions capsula/_watcher/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ def encapsulate(self) -> timedelta | None:
@contextmanager
def watch(self) -> Iterator[None]:
start = time.perf_counter()
yield
end = time.perf_counter()
self.duration = timedelta(seconds=end - start)
logger.debug(f"TimeWatcher: {self.name} took {self.duration}.")
try:
yield
finally:
end = time.perf_counter()
self.duration = timedelta(seconds=end - start)
logger.debug(f"TimeWatcher: {self.name} took {self.duration}.")

def default_key(self) -> tuple[str, str]:
return ("time", self.name)
4 changes: 2 additions & 2 deletions capsula/encapsulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def add_watcher(self, watcher: WatcherBase, key: _CapsuleItemKey | None = None)
raise KeyConflictError(key)
self.watchers[key] = watcher

def encapsulate(self, *, abort_on_error: bool = True) -> Capsule:
def encapsulate(self) -> Capsule:
data = {}
fails = {}
for key, capsule_item in chain(self.contexts.items(), self.watchers.items()):
try:
data[key] = capsule_item.encapsulate()
except Exception as e: # noqa: PERF203
if abort_on_error:
if capsule_item.abort_on_error:
raise
fails[key] = ExceptionInfo.from_exception(e)
return Capsule(data, fails)
Expand Down
2 changes: 1 addition & 1 deletion examples/low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
in_run_enc.add_watcher(capsula.TimeWatcher("calculation_time"))

# Catch the exception raised by the encapsulated function.
in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception, reraise=False))
in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception))

with in_run_enc.watch():
logger.info(f"Calculating pi with {N_SAMPLES} samples.")
Expand Down