diff --git a/capsula/_capsule.py b/capsula/_capsule.py index d180dca1..849b29fb 100644 --- a/capsula/_capsule.py +++ b/capsula/_capsule.py @@ -26,6 +26,11 @@ def __init__( class CapsuleItem(ABC): + @property + @abstractmethod + def abort_on_error(self) -> bool: + raise NotImplementedError + @abstractmethod def encapsulate(self) -> Any: raise NotImplementedError diff --git a/capsula/_context/_base.py b/capsula/_context/_base.py index 33b06ad8..72582df7 100644 --- a/capsula/_context/_base.py +++ b/capsula/_context/_base.py @@ -7,6 +7,7 @@ class ContextBase(CapsuleItem): _subclass_registry: Final[dict[str, type[ContextBase]]] = {} + abort_on_error: bool = False def __init_subclass__(cls, **kwargs: Any) -> None: if cls.__name__ in cls._subclass_registry: diff --git a/capsula/_run.py b/capsula/_run.py index f6fcd665..d8f71e03 100644 --- a/capsula/_run.py +++ b/capsula/_run.py @@ -1,6 +1,7 @@ from __future__ import annotations import inspect +import logging import queue import threading from collections import deque @@ -28,6 +29,8 @@ _P = ParamSpec("_P") _T = TypeVar("_T") +logger = logging.getLogger(__name__) + class FuncInfo(BaseModel): func: Callable[..., Any] @@ -236,7 +239,7 @@ def __exit__( ) -> None: self._get_run_stack().get(block=False) - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: # noqa: C901 func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs) if self._run_dir_generator is None: msg = "run_dir_generator must be set before calling the function." @@ -266,26 +269,32 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: in_run_enc.add_context(FunctionCallContext(self._func, args, kwargs)) - # TODO: `result` will not be defined if `self.func` raises an exception and it is caught by a watcher. - with self, in_run_enc, in_run_enc.watch(): - if self._pass_pre_run_capsule: - result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type] - else: - result = self._func(*args, **kwargs) - - in_run_capsule = in_run_enc.encapsulate() - for reporter_generator in self._in_run_reporter_generators: - reporter = reporter_generator(params) - reporter.report(in_run_capsule) - - params.phase = "post" - post_run_enc = Encapsulator() - for context_generator in self._post_run_context_generators: - context = context_generator(params) - post_run_enc.add_context(context) - post_run_capsule = post_run_enc.encapsulate() - for reporter_generator in self._post_run_reporter_generators: - reporter = reporter_generator(params) - reporter.report(post_run_capsule) + try: + with self, in_run_enc, in_run_enc.watch(): + if self._pass_pre_run_capsule: + result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type] + else: + result = self._func(*args, **kwargs) + finally: + in_run_capsule = in_run_enc.encapsulate() + for reporter_generator in self._in_run_reporter_generators: + reporter = reporter_generator(params) + try: + reporter.report(in_run_capsule) + except Exception: + logger.exception(f"Failed to report in-run capsule with reporter {reporter}.") + + params.phase = "post" + post_run_enc = Encapsulator() + for context_generator in self._post_run_context_generators: + context = context_generator(params) + post_run_enc.add_context(context) + post_run_capsule = post_run_enc.encapsulate() + for reporter_generator in self._post_run_reporter_generators: + reporter = reporter_generator(params) + try: + reporter.report(post_run_capsule) + except Exception: + logger.exception(f"Failed to report post-run capsule with reporter {reporter}.") return result diff --git a/capsula/_watcher/_base.py b/capsula/_watcher/_base.py index 6e8f63e8..d2731bb4 100644 --- a/capsula/_watcher/_base.py +++ b/capsula/_watcher/_base.py @@ -1,7 +1,7 @@ from __future__ import annotations import queue -from abc import abstractmethod +from abc import ABC, abstractmethod from collections.abc import Hashable from typing import TYPE_CHECKING, Any, Dict, Final, Generic, OrderedDict, TypeVar @@ -12,8 +12,9 @@ from types import TracebackType -class WatcherBase(CapsuleItem): +class WatcherBase(CapsuleItem, ABC): _subclass_registry: Final[dict[str, type[WatcherBase]]] = {} + abort_on_error: bool = False def __init_subclass__(cls, **kwargs: Any) -> None: if cls.__name__ in cls._subclass_registry: diff --git a/capsula/_watcher/_exception.py b/capsula/_watcher/_exception.py index a2b7befd..0fc072ee 100644 --- a/capsula/_watcher/_exception.py +++ b/capsula/_watcher/_exception.py @@ -20,11 +20,9 @@ def __init__( name: str = "exception", *, base: type[BaseException] = Exception, - reraise: bool = False, ) -> None: self.name = name self.base = base - self.reraise = reraise self.exception: BaseException | None = None def encapsulate(self) -> ExceptionInfo: @@ -36,10 +34,9 @@ def watch(self) -> Iterator[None]: try: yield except self.base as e: - logger.debug(f"UncaughtExceptionWatcher: {self.name} caught exception: {e}") + logger.debug(f"UncaughtExceptionWatcher: {self.name} observed exception: {e}") self.exception = e - if self.reraise: - raise + raise def default_key(self) -> tuple[str, str]: return ("exception", self.name) diff --git a/capsula/_watcher/_time.py b/capsula/_watcher/_time.py index f84aa383..4ac555ef 100644 --- a/capsula/_watcher/_time.py +++ b/capsula/_watcher/_time.py @@ -25,10 +25,12 @@ def encapsulate(self) -> timedelta | None: @contextmanager def watch(self) -> Iterator[None]: start = time.perf_counter() - yield - end = time.perf_counter() - self.duration = timedelta(seconds=end - start) - logger.debug(f"TimeWatcher: {self.name} took {self.duration}.") + try: + yield + finally: + end = time.perf_counter() + self.duration = timedelta(seconds=end - start) + logger.debug(f"TimeWatcher: {self.name} took {self.duration}.") def default_key(self) -> tuple[str, str]: return ("time", self.name) diff --git a/capsula/encapsulator.py b/capsula/encapsulator.py index 5e1b00f7..ffe0abff 100644 --- a/capsula/encapsulator.py +++ b/capsula/encapsulator.py @@ -83,14 +83,14 @@ def add_watcher(self, watcher: WatcherBase, key: _CapsuleItemKey | None = None) raise KeyConflictError(key) self.watchers[key] = watcher - def encapsulate(self, *, abort_on_error: bool = True) -> Capsule: + def encapsulate(self) -> Capsule: data = {} fails = {} for key, capsule_item in chain(self.contexts.items(), self.watchers.items()): try: data[key] = capsule_item.encapsulate() except Exception as e: # noqa: PERF203 - if abort_on_error: + if capsule_item.abort_on_error: raise fails[key] = ExceptionInfo.from_exception(e) return Capsule(data, fails) diff --git a/examples/low_level.py b/examples/low_level.py index 3a41dcd6..733ae485 100644 --- a/examples/low_level.py +++ b/examples/low_level.py @@ -81,7 +81,7 @@ in_run_enc.add_watcher(capsula.TimeWatcher("calculation_time")) # Catch the exception raised by the encapsulated function. -in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception, reraise=False)) +in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception)) with in_run_enc.watch(): logger.info(f"Calculating pi with {N_SAMPLES} samples.")