Skip to content

Commit

Permalink
Merge pull request #251 from shunichironomura/improve-exception-handling
Browse files Browse the repository at this point in the history
Improve exception handling
  • Loading branch information
shunichironomura authored Jun 29, 2024
2 parents 5c3422d + 4ef8e21 commit 35ae35a
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 36 deletions.
5 changes: 5 additions & 0 deletions capsula/_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def __init__(


class CapsuleItem(ABC):
@property
@abstractmethod
def abort_on_error(self) -> bool:
raise NotImplementedError

@abstractmethod
def encapsulate(self) -> Any:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions capsula/_context/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class ContextBase(CapsuleItem):
_subclass_registry: Final[dict[str, type[ContextBase]]] = {}
abort_on_error: bool = False

def __init_subclass__(cls, **kwargs: Any) -> None:
if cls.__name__ in cls._subclass_registry:
Expand Down
53 changes: 31 additions & 22 deletions capsula/_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import logging
import queue
import threading
from collections import deque
Expand Down Expand Up @@ -28,6 +29,8 @@
_P = ParamSpec("_P")
_T = TypeVar("_T")

logger = logging.getLogger(__name__)


class FuncInfo(BaseModel):
func: Callable[..., Any]
Expand Down Expand Up @@ -236,7 +239,7 @@ def __exit__(
) -> None:
self._get_run_stack().get(block=False)

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: # noqa: C901
func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs)
if self._run_dir_generator is None:
msg = "run_dir_generator must be set before calling the function."
Expand Down Expand Up @@ -266,26 +269,32 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:

in_run_enc.add_context(FunctionCallContext(self._func, args, kwargs))

# TODO: `result` will not be defined if `self.func` raises an exception and it is caught by a watcher.
with self, in_run_enc, in_run_enc.watch():
if self._pass_pre_run_capsule:
result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self._func(*args, **kwargs)

in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
reporter.report(in_run_capsule)

params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
reporter.report(post_run_capsule)
try:
with self, in_run_enc, in_run_enc.watch():
if self._pass_pre_run_capsule:
result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self._func(*args, **kwargs)
finally:
in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(in_run_capsule)
except Exception:
logger.exception(f"Failed to report in-run capsule with reporter {reporter}.")

params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(post_run_capsule)
except Exception:
logger.exception(f"Failed to report post-run capsule with reporter {reporter}.")

return result
5 changes: 3 additions & 2 deletions capsula/_watcher/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import queue
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING, Any, Dict, Final, Generic, OrderedDict, TypeVar

Expand All @@ -12,8 +12,9 @@
from types import TracebackType


class WatcherBase(CapsuleItem):
class WatcherBase(CapsuleItem, ABC):
_subclass_registry: Final[dict[str, type[WatcherBase]]] = {}
abort_on_error: bool = False

def __init_subclass__(cls, **kwargs: Any) -> None:
if cls.__name__ in cls._subclass_registry:
Expand Down
7 changes: 2 additions & 5 deletions capsula/_watcher/_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ def __init__(
name: str = "exception",
*,
base: type[BaseException] = Exception,
reraise: bool = False,
) -> None:
self.name = name
self.base = base
self.reraise = reraise
self.exception: BaseException | None = None

def encapsulate(self) -> ExceptionInfo:
Expand All @@ -36,10 +34,9 @@ def watch(self) -> Iterator[None]:
try:
yield
except self.base as e:
logger.debug(f"UncaughtExceptionWatcher: {self.name} caught exception: {e}")
logger.debug(f"UncaughtExceptionWatcher: {self.name} observed exception: {e}")
self.exception = e
if self.reraise:
raise
raise

def default_key(self) -> tuple[str, str]:
return ("exception", self.name)
10 changes: 6 additions & 4 deletions capsula/_watcher/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ def encapsulate(self) -> timedelta | None:
@contextmanager
def watch(self) -> Iterator[None]:
start = time.perf_counter()
yield
end = time.perf_counter()
self.duration = timedelta(seconds=end - start)
logger.debug(f"TimeWatcher: {self.name} took {self.duration}.")
try:
yield
finally:
end = time.perf_counter()
self.duration = timedelta(seconds=end - start)
logger.debug(f"TimeWatcher: {self.name} took {self.duration}.")

def default_key(self) -> tuple[str, str]:
return ("time", self.name)
4 changes: 2 additions & 2 deletions capsula/encapsulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,14 @@ def add_watcher(self, watcher: WatcherBase, key: _CapsuleItemKey | None = None)
raise KeyConflictError(key)
self.watchers[key] = watcher

def encapsulate(self, *, abort_on_error: bool = True) -> Capsule:
def encapsulate(self) -> Capsule:
data = {}
fails = {}
for key, capsule_item in chain(self.contexts.items(), self.watchers.items()):
try:
data[key] = capsule_item.encapsulate()
except Exception as e: # noqa: PERF203
if abort_on_error:
if capsule_item.abort_on_error:
raise
fails[key] = ExceptionInfo.from_exception(e)
return Capsule(data, fails)
Expand Down
2 changes: 1 addition & 1 deletion examples/low_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
in_run_enc.add_watcher(capsula.TimeWatcher("calculation_time"))

# Catch the exception raised by the encapsulated function.
in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception, reraise=False))
in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception))

with in_run_enc.watch():
logger.info(f"Calculating pi with {N_SAMPLES} samples.")
Expand Down

0 comments on commit 35ae35a

Please sign in to comment.