diff --git a/capsula/__init__.py b/capsula/__init__.py
index 7b1dd4f4..400990cf 100644
--- a/capsula/__init__.py
+++ b/capsula/__init__.py
@@ -8,9 +8,44 @@
"set_capsule_dir",
"set_capsule_name",
"Encapsulator",
+ "capsule",
+ "record",
+ "Run",
+ "ContextBase",
+ "CwdContext",
+ "EnvVarContext",
+ "GitRepositoryContext",
+ "FileContext",
+ "PlatformContext",
+ "CpuContext",
+ "CommandContext",
+ "JsonDumpReporter",
+ "ReporterBase",
+ "WatcherBase",
+ "TimeWatcher",
+ "watcher",
+ "reporter",
+ "context",
+ "UncaughtExceptionWatcher",
+ "run",
]
+from ._context import (
+ CommandContext,
+ ContextBase,
+ CpuContext,
+ CwdContext,
+ EnvVarContext,
+ FileContext,
+ GitRepositoryContext,
+ PlatformContext,
+)
+from ._decorator import capsule, context, reporter, run, watcher
from ._monitor import monitor
+from ._reporter import JsonDumpReporter, ReporterBase
+from ._root import record
+from ._run import Run
from ._version import __version__
+from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase
from .encapsulator import Encapsulator
from .exceptions import CapsulaConfigurationError, CapsulaError
from .globalvars import get_capsule_dir, get_capsule_name, set_capsule_dir, set_capsule_name
diff --git a/capsula/_backport.py b/capsula/_backport.py
index 91ad2872..00bb83d1 100644
--- a/capsula/_backport.py
+++ b/capsula/_backport.py
@@ -1,10 +1,6 @@
from __future__ import annotations
-__all__ = [
- "TypeAlias",
- "file_digest",
- "Self",
-]
+__all__ = ["TypeAlias", "file_digest", "Self", "ParamSpec"]
import hashlib
import sys
@@ -16,9 +12,9 @@
from typing_extensions import Self
if sys.version_info >= (3, 10):
- from typing import TypeAlias
+ from typing import ParamSpec, TypeAlias
else:
- from typing_extensions import TypeAlias
+ from typing_extensions import ParamSpec, TypeAlias
if sys.version_info >= (3, 11):
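The ParamSpec backport follows the existing TypeAlias/Self pattern: import from typing on 3.10+, otherwise from typing_extensions. A minimal sketch of what it enables for the new decorators, assuming only this backport (the helper name is illustrative):

```python
from typing import Callable, TypeVar

from capsula._backport import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")

def passthrough(func: Callable[P, T]) -> Callable[P, T]:
    # ParamSpec keeps the wrapped function's parameter types visible to type checkers,
    # even on Python versions older than 3.10.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        return func(*args, **kwargs)
    return wrapper
```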
diff --git a/capsula/capsule.py b/capsula/_capsule.py
similarity index 100%
rename from capsula/capsule.py
rename to capsula/_capsule.py
diff --git a/capsula/context/__init__.py b/capsula/_context/__init__.py
similarity index 88%
rename from capsula/context/__init__.py
rename to capsula/_context/__init__.py
index 470e934d..23519d45 100644
--- a/capsula/context/__init__.py
+++ b/capsula/_context/__init__.py
@@ -1,5 +1,5 @@
__all__ = [
- "Context",
+ "ContextBase",
"CwdContext",
"EnvVarContext",
"GitRepositoryContext",
@@ -8,7 +8,7 @@
"CpuContext",
"CommandContext",
]
-from ._base import Context
+from ._base import ContextBase
from ._command import CommandContext
from ._cpu import CpuContext
from ._cwd import CwdContext
diff --git a/capsula/_context/_base.py b/capsula/_context/_base.py
new file mode 100644
index 00000000..eb5746da
--- /dev/null
+++ b/capsula/_context/_base.py
@@ -0,0 +1,5 @@
+from capsula._capsule import CapsuleItem
+
+
+class ContextBase(CapsuleItem):
+ pass
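ContextBase replaces the old Context base class, and user-defined contexts subclass it the same way. A hedged sketch of a hypothetical custom context, assuming the CapsuleItem interface used by the built-in contexts (encapsulate(), plus default_key() when no explicit key is passed to the encapsulator):

```python
import sys

from capsula import ContextBase

class PythonVersionContext(ContextBase):
    def encapsulate(self) -> dict:
        # Whatever encapsulate() returns is stored in the capsule under this item's key.
        return {"version": sys.version, "executable": sys.executable}

    def default_key(self) -> str:
        return "python_version"
```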
diff --git a/capsula/context/_command.py b/capsula/_context/_command.py
similarity index 93%
rename from capsula/context/_command.py
rename to capsula/_context/_command.py
index bf1a79dd..52e7c004 100644
--- a/capsula/context/_command.py
+++ b/capsula/_context/_command.py
@@ -7,12 +7,12 @@
if TYPE_CHECKING:
from pathlib import Path
-from ._base import Context
+from ._base import ContextBase
logger = logging.getLogger(__name__)
-class CommandContext(Context):
+class CommandContext(ContextBase):
def __init__(self, command: str, cwd: Path | None = None) -> None:
self.command = command
self.cwd = cwd
diff --git a/capsula/context/_cpu.py b/capsula/_context/_cpu.py
similarity index 71%
rename from capsula/context/_cpu.py
rename to capsula/_context/_cpu.py
index 7ab20029..38fb1a9d 100644
--- a/capsula/context/_cpu.py
+++ b/capsula/_context/_cpu.py
@@ -1,9 +1,9 @@
from cpuinfo import get_cpu_info
-from ._base import Context
+from ._base import ContextBase
-class CpuContext(Context):
+class CpuContext(ContextBase):
def encapsulate(self) -> dict:
return get_cpu_info()
diff --git a/capsula/context/_cwd.py b/capsula/_context/_cwd.py
similarity index 70%
rename from capsula/context/_cwd.py
rename to capsula/_context/_cwd.py
index d5458b6f..fa804ea1 100644
--- a/capsula/context/_cwd.py
+++ b/capsula/_context/_cwd.py
@@ -1,9 +1,9 @@
from pathlib import Path
-from ._base import Context
+from ._base import ContextBase
-class CwdContext(Context):
+class CwdContext(ContextBase):
def encapsulate(self) -> Path:
return Path.cwd()
diff --git a/capsula/context/_envvar.py b/capsula/_context/_envvar.py
similarity index 80%
rename from capsula/context/_envvar.py
rename to capsula/_context/_envvar.py
index 58dd5da4..96283bec 100644
--- a/capsula/context/_envvar.py
+++ b/capsula/_context/_envvar.py
@@ -2,10 +2,10 @@
import os
-from ._base import Context
+from ._base import ContextBase
-class EnvVarContext(Context):
+class EnvVarContext(ContextBase):
def __init__(self, name: str) -> None:
self.name = name
diff --git a/capsula/context/_file.py b/capsula/_context/_file.py
similarity index 83%
rename from capsula/context/_file.py
rename to capsula/_context/_file.py
index 4c1db765..5ce0f005 100644
--- a/capsula/context/_file.py
+++ b/capsula/_context/_file.py
@@ -7,12 +7,12 @@
from capsula._backport import file_digest
-from ._base import Context
+from ._base import ContextBase
logger = logging.getLogger(__name__)
-class FileContext(Context):
+class FileContext(ContextBase):
def __init__(
self,
path: Path | str,
@@ -32,13 +32,11 @@ def __init__(
else:
self.copy_to = tuple(Path(p) for p in copy_to)
- def normalize_copy_dst_path(p: Path) -> Path:
- if p.is_dir():
- return p / self.path.name
- else:
- return p
-
- self.copy_to = tuple(normalize_copy_dst_path(p) for p in self.copy_to)
+ def _normalize_copy_dst_path(self, p: Path) -> Path:
+ if p.is_dir():
+ return p / self.path.name
+ else:
+ return p
def encapsulate(self) -> dict:
if self.hash_algorithm is None:
@@ -47,6 +45,8 @@ def encapsulate(self) -> dict:
with self.path.open("rb") as f:
digest = file_digest(f, self.hash_algorithm).hexdigest()
+ self.copy_to = tuple(self._normalize_copy_dst_path(p) for p in self.copy_to)
+
info: dict = {
"hash": {
"algorithm": self.hash_algorithm,
diff --git a/capsula/context/_git.py b/capsula/_context/_git.py
similarity index 66%
rename from capsula/context/_git.py
rename to capsula/_context/_git.py
index 4cc3ca92..0c53216d 100644
--- a/capsula/context/_git.py
+++ b/capsula/_context/_git.py
@@ -1,13 +1,18 @@
from __future__ import annotations
+import inspect
import logging
from pathlib import Path
+from typing import TYPE_CHECKING, Callable
from git.repo import Repo
from capsula.exceptions import CapsulaError
-from ._base import Context
+from ._base import ContextBase
+
+if TYPE_CHECKING:
+ from capsula._decorator import CapsuleParams
logger = logging.getLogger(__name__)
@@ -18,7 +23,7 @@ def __init__(self, repo: Repo) -> None:
super().__init__(f"Repository {repo.working_dir} is dirty")
-class GitRepositoryContext(Context):
+class GitRepositoryContext(ContextBase):
def __init__(
self,
name: str,
@@ -57,3 +62,19 @@ def encapsulate(self) -> dict:
def default_key(self) -> tuple[str, str]:
return ("git", self.name)
+
+ @classmethod
+ def default(cls) -> Callable[[CapsuleParams], GitRepositoryContext]:
+ def callback(params: CapsuleParams) -> GitRepositoryContext:
+ func_file_path = Path(inspect.getfile(params.func))
+ repo = Repo(func_file_path.parent, search_parent_directories=True)
+ repo_name = Path(repo.working_dir).name
+ return cls(
+                name=repo_name,
+ path=Path(repo.working_dir),
+ diff_file=params.run_dir / f"{repo_name}.diff",
+ search_parent_directories=False,
+ allow_dirty=True,
+ )
+
+        return callback
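GitRepositoryContext.default() returns a callback rather than an instance, so the repository is resolved from the decorated function's source file and the diff file is placed in the run directory only once CapsuleParams is available. A hedged usage sketch with the decorators introduced in this change (the run_dir layout is illustrative):

```python
from pathlib import Path

import capsula

@capsula.run(run_dir=lambda info: Path("vault") / info.func.__name__)
@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
def train() -> None:
    ...  # the pre-run capsule records the repository state and writes its diff file
```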
diff --git a/capsula/context/_platform.py b/capsula/_context/_platform.py
similarity index 93%
rename from capsula/context/_platform.py
rename to capsula/_context/_platform.py
index c365f5c1..d401fdfd 100644
--- a/capsula/context/_platform.py
+++ b/capsula/_context/_platform.py
@@ -1,9 +1,9 @@
import platform as pf
-from ._base import Context
+from ._base import ContextBase
-class PlatformContext(Context):
+class PlatformContext(ContextBase):
def encapsulate(self) -> dict:
return {
"machine": pf.machine(),
diff --git a/capsula/_decorator.py b/capsula/_decorator.py
new file mode 100644
index 00000000..bb87b35a
--- /dev/null
+++ b/capsula/_decorator.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from functools import wraps
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Literal, Tuple, TypeVar, Union
+
+from capsula._reporter import ReporterBase
+from capsula.encapsulator import Encapsulator
+
+from ._backport import ParamSpec
+from ._context import ContextBase
+from ._run import CapsuleParams, FuncInfo, Run
+from ._watcher import WatcherBase
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+
+ from ._backport import TypeAlias
+
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
+
+_ContextInput: TypeAlias = Union[
+ ContextBase,
+ Tuple[ContextBase, Tuple[str, ...]],
+ Callable[[Path, Callable], Union[ContextBase, Tuple[ContextBase, Tuple[str, ...]]]],
+]
+_WatcherInput: TypeAlias = Union[
+ WatcherBase,
+ Tuple[WatcherBase, Tuple[str, ...]],
+ Callable[[Path, Callable], Union[WatcherBase, Tuple[WatcherBase, Tuple[str, ...]]]],
+]
+_ReporterInput: TypeAlias = Union[ReporterBase, Callable[[Path, Callable], ReporterBase]]
+
+
+def capsule( # noqa: C901
+ capsule_directory: Path | str | None = None,
+ pre_run_contexts: Sequence[_ContextInput] | None = None,
+ pre_run_reporters: Sequence[_ReporterInput] | None = None,
+ in_run_watchers: Sequence[_WatcherInput] | None = None,
+ post_run_contexts: Sequence[_ContextInput] | None = None,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+ if capsule_directory is None:
+ raise NotImplementedError
+ capsule_directory = Path(capsule_directory)
+
+ assert pre_run_contexts is not None
+ assert pre_run_reporters is not None
+ assert in_run_watchers is not None
+ assert post_run_contexts is not None
+
+ def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
+ pre_run_enc = Encapsulator()
+ for cxt in pre_run_contexts:
+ if isinstance(cxt, ContextBase):
+ pre_run_enc.add_context(cxt)
+ elif isinstance(cxt, tuple):
+ pre_run_enc.add_context(cxt[0], key=cxt[1])
+ else:
+ cxt_hydrated = cxt(capsule_directory, func)
+ if isinstance(cxt_hydrated, ContextBase):
+ pre_run_enc.add_context(cxt_hydrated)
+ elif isinstance(cxt_hydrated, tuple):
+ pre_run_enc.add_context(cxt_hydrated[0], key=cxt_hydrated[1])
+
+ @wraps(func)
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
+ capsule_directory.mkdir(parents=True, exist_ok=True)
+ pre_run_capsule = pre_run_enc.encapsulate()
+ for reporter in pre_run_reporters:
+ if isinstance(reporter, ReporterBase):
+ reporter.report(pre_run_capsule)
+ else:
+ reporter(capsule_directory, func).report(pre_run_capsule)
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def watcher(
+ watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase],
+) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
+ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
+ func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run.add_watcher(watcher)
+ return run
+
+ return decorator
+
+
+def reporter(
+ reporter: ReporterBase | Callable[[CapsuleParams], ReporterBase],
+ mode: Literal["pre", "in", "post", "all"],
+) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
+ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
+ func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run.add_reporter(reporter, mode=mode)
+ return run
+
+ return decorator
+
+
+def context(
+ context: ContextBase | Callable[[CapsuleParams], ContextBase],
+ mode: Literal["pre", "post", "all"],
+) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
+ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
+ func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run.add_context(context, mode=mode)
+ return run
+
+ return decorator
+
+
+def run(
+ run_dir: Path | Callable[[FuncInfo], Path],
+) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
+ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
+ func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run.set_run_dir(run_dir)
+
+ return run
+
+ return decorator
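Ordering matters for the new watcher/reporter/context/run decorators: each of them wraps the function in a Run (or reuses the one built by the decorator below it), and @capsula.run has to sit outermost because Run.__call__ raises ValueError when no run directory has been set. A minimal sketch (paths and names are illustrative; see examples/decorator.py below for a fuller stack):

```python
from pathlib import Path

import capsula

@capsula.run(run_dir=lambda info: Path("vault") / info.func.__name__)
@capsula.watcher(capsula.TimeWatcher("duration"))
@capsula.watcher(capsula.UncaughtExceptionWatcher("exception"))
def job() -> None:
    ...  # runs inside the in-run encapsulator with both watchers active
```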
diff --git a/capsula/_reporter/__init__.py b/capsula/_reporter/__init__.py
new file mode 100644
index 00000000..fff6b94a
--- /dev/null
+++ b/capsula/_reporter/__init__.py
@@ -0,0 +1,3 @@
+__all__ = ["JsonDumpReporter", "ReporterBase"]
+from ._base import ReporterBase
+from ._json import JsonDumpReporter
diff --git a/capsula/reporter/_base.py b/capsula/_reporter/_base.py
similarity index 87%
rename from capsula/reporter/_base.py
rename to capsula/_reporter/_base.py
index fe8c4c84..e1926e9d 100644
--- a/capsula/reporter/_base.py
+++ b/capsula/_reporter/_base.py
@@ -3,7 +3,7 @@
from capsula.encapsulator import Capsule
-class Reporter(ABC):
+class ReporterBase(ABC):
@abstractmethod
def report(self, capsule: Capsule) -> None:
raise NotImplementedError
diff --git a/capsula/reporter/_json.py b/capsula/_reporter/_json.py
similarity index 78%
rename from capsula/reporter/_json.py
rename to capsula/_reporter/_json.py
index 02416962..0661e17c 100644
--- a/capsula/reporter/_json.py
+++ b/capsula/_reporter/_json.py
@@ -1,8 +1,10 @@
from __future__ import annotations
import logging
+import traceback
from datetime import timedelta
from pathlib import Path
+from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Optional
import orjson
@@ -12,30 +14,37 @@
if TYPE_CHECKING:
from capsula.encapsulator import Capsule
-from ._base import Reporter
+from ._base import ReporterBase
logger = logging.getLogger(__name__)
def default_preset(obj: Any) -> Any:
- # if isinstance(obj, timedelta):
- # return str(obj)
if isinstance(obj, Path):
return str(obj)
if isinstance(obj, timedelta):
return str(obj)
+ if isinstance(obj, type) and issubclass(obj, BaseException):
+ return obj.__name__
+ if isinstance(obj, Exception):
+ return str(obj)
+ if isinstance(obj, TracebackType):
+ return "".join(traceback.format_tb(obj))
raise TypeError
-class JsonDumpReporter(Reporter):
+class JsonDumpReporter(ReporterBase):
def __init__(
self,
path: Path | str,
*,
default: Optional[Callable[[Any], Any]] = None,
option: Optional[int] = None,
+ mkdir: bool = True,
) -> None:
self.path = Path(path)
+ if mkdir:
+ self.path.parent.mkdir(parents=True, exist_ok=True)
if default is None:
self.default = default_preset
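The extended default_preset makes exception classes, exception instances, and traceback objects (as produced by UncaughtExceptionWatcher) JSON-serializable without a custom default hook, and mkdir=True creates the report's parent directory eagerly. A quick sketch of the serialization behaviour, assuming the internal module path stays as in this diff:

```python
import orjson

from capsula._reporter._json import default_preset

try:
    1 / 0
except ZeroDivisionError as exc:
    blob = orjson.dumps(
        {"type": type(exc), "value": exc, "traceback": exc.__traceback__},
        default=default_preset,
        option=orjson.OPT_INDENT_2,
    )
    # The exception class serializes to its name, the instance to str(exc),
    # and the traceback to its formatted frames.
    print(blob.decode())
```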
diff --git a/capsula/_root.py b/capsula/_root.py
new file mode 100644
index 00000000..bec1a441
--- /dev/null
+++ b/capsula/_root.py
@@ -0,0 +1,11 @@
+from typing import Any
+
+from .encapsulator import Encapsulator, _CapsuleItemKey
+
+
+def record(key: _CapsuleItemKey, value: Any) -> None:
+ enc = Encapsulator.get_current()
+ if enc is None:
+ msg = "No active encapsulator found."
+ raise RuntimeError(msg)
+ enc.record(key, value)
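capsula.record() looks up the innermost active Encapsulator on the current thread (pushed either by the run machinery or by entering an Encapsulator directly) and raises RuntimeError when none is active. Minimal sketch:

```python
import capsula

with capsula.Encapsulator() as enc:
    capsula.record("answer", 42)   # stored via ObjectContext under the key "answer"
    capsule = enc.encapsulate()

capsula.record("orphan", 0)        # RuntimeError: no active encapsulator
```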
diff --git a/capsula/_run.py b/capsula/_run.py
new file mode 100644
index 00000000..bf0c38c8
--- /dev/null
+++ b/capsula/_run.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import queue
+import threading
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar
+
+from pydantic import BaseModel
+
+from capsula._reporter import ReporterBase
+from capsula.encapsulator import Encapsulator
+
+from ._backport import ParamSpec, Self
+from ._context import ContextBase
+from ._watcher import WatcherBase
+
+if TYPE_CHECKING:
+ from types import TracebackType
+
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
+
+class FuncInfo(BaseModel):
+ func: Callable
+ args: tuple
+ kwargs: dict
+
+
+class CapsuleParams(FuncInfo):
+ run_dir: Path
+ phase: Literal["pre", "in", "post"]
+
+
+class Run(Generic[_P, _T]):
+ _thread_local = threading.local()
+
+ @classmethod
+ def _get_run_stack(cls) -> queue.LifoQueue[Self]:
+ if not hasattr(cls._thread_local, "run_stack"):
+ cls._thread_local.run_stack = queue.LifoQueue()
+ return cls._thread_local.run_stack
+
+ @classmethod
+ def get_current(cls) -> Self | None:
+ try:
+ return cls._get_run_stack().queue[-1]
+ except IndexError:
+ return None
+
+ def __init__(self, func: Callable[_P, _T]) -> None:
+ self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
+ self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = []
+ self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
+
+ self.pre_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
+ self.in_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
+ self.post_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
+
+ self.func = func
+ self.run_dir_generator: Callable[[FuncInfo], Path] | None = None
+
+ def add_context(
+ self,
+ context: ContextBase | Callable[[CapsuleParams], ContextBase],
+ *,
+ mode: Literal["pre", "post", "all"],
+ ) -> None:
+ def context_generator(params: CapsuleParams) -> ContextBase:
+ if isinstance(context, ContextBase):
+ return context
+ else:
+ return context(params)
+
+ if mode == "pre":
+ self.pre_run_context_generators.append(context_generator)
+ elif mode == "post":
+ self.post_run_context_generators.append(context_generator)
+ elif mode == "all":
+ self.pre_run_context_generators.append(context_generator)
+ self.post_run_context_generators.append(context_generator)
+ else:
+ msg = f"mode must be one of 'pre', 'post', or 'all', not {mode}."
+ raise ValueError(msg)
+
+ def add_watcher(self, watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase]) -> None:
+ def watcher_generator(params: CapsuleParams) -> WatcherBase:
+ if isinstance(watcher, WatcherBase):
+ return watcher
+ else:
+ return watcher(params)
+
+ self.in_run_watcher_generators.append(watcher_generator)
+
+ def add_reporter(
+ self,
+ reporter: ReporterBase | Callable[[CapsuleParams], ReporterBase],
+ *,
+ mode: Literal["pre", "in", "post", "all"],
+ ) -> None:
+ def reporter_generator(params: CapsuleParams) -> ReporterBase:
+ if isinstance(reporter, ReporterBase):
+ return reporter
+ else:
+ return reporter(params)
+
+ if mode == "pre":
+ self.pre_run_reporter_generators.append(reporter_generator)
+ elif mode == "in":
+ self.in_run_reporter_generators.append(reporter_generator)
+ elif mode == "post":
+ self.post_run_reporter_generators.append(reporter_generator)
+ elif mode == "all":
+ self.pre_run_reporter_generators.append(reporter_generator)
+ self.in_run_reporter_generators.append(reporter_generator)
+ self.post_run_reporter_generators.append(reporter_generator)
+ else:
+ msg = f"mode must be one of 'pre', 'in', 'post', or 'all', not {mode}."
+ raise ValueError(msg)
+
+ def set_run_dir(self, run_dir: Path | Callable[[FuncInfo], Path]) -> None:
+ def run_dir_generator(params: FuncInfo) -> Path:
+ if isinstance(run_dir, Path):
+ return run_dir
+ else:
+ return run_dir(params)
+
+ self.run_dir_generator = run_dir_generator
+
+ def __enter__(self) -> Self:
+ self._get_run_stack().put(self)
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ self._get_run_stack().get(block=False)
+
+ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
+ func_info = FuncInfo(func=self.func, args=args, kwargs=kwargs)
+ if self.run_dir_generator is None:
+ msg = "run_dir_generator must be set before calling the function."
+ raise ValueError(msg)
+ run_dir = self.run_dir_generator(func_info)
+ run_dir.mkdir(parents=True, exist_ok=True)
+ params = CapsuleParams(
+ func=func_info.func,
+ args=func_info.args,
+ kwargs=func_info.kwargs,
+ run_dir=run_dir,
+ phase="pre",
+ )
+
+ pre_run_enc = Encapsulator()
+ for context_generator in self.pre_run_context_generators:
+ context = context_generator(params)
+ pre_run_enc.add_context(context)
+ pre_run_capsule = pre_run_enc.encapsulate()
+ for reporter_generator in self.pre_run_reporter_generators:
+ reporter = reporter_generator(params)
+ reporter.report(pre_run_capsule)
+
+ params.phase = "in"
+ in_run_enc = Encapsulator()
+ for watcher_generator in self.in_run_watcher_generators:
+ watcher = watcher_generator(params)
+ in_run_enc.add_watcher(watcher)
+
+ with self, in_run_enc, in_run_enc.watch():
+ result = self.func(*args, **kwargs)
+
+ in_run_capsule = in_run_enc.encapsulate()
+ for reporter_generator in self.in_run_reporter_generators:
+ reporter = reporter_generator(params)
+ reporter.report(in_run_capsule)
+
+ params.phase = "post"
+ post_run_enc = Encapsulator()
+ for context_generator in self.post_run_context_generators:
+ context = context_generator(params)
+ post_run_enc.add_context(context)
+ post_run_capsule = post_run_enc.encapsulate()
+ for reporter_generator in self.post_run_reporter_generators:
+ reporter = reporter_generator(params)
+ reporter.report(post_run_capsule)
+
+ return result
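Run can also be driven without the decorators; the sketch below assumes the API defined above (set_run_dir, add_watcher, add_reporter) and an illustrative vault layout:

```python
from pathlib import Path

import capsula

def double(x: int) -> int:
    return x * 2

run = capsula.Run(double)
run.set_run_dir(Path("vault") / "manual-run")      # or a callable taking FuncInfo
run.add_watcher(capsula.TimeWatcher("duration"))
run.add_reporter(
    capsula.JsonDumpReporter(Path("vault") / "manual-run" / "in-report.json"),
    mode="in",
)
result = run(21)  # runs double(21) inside the in-run encapsulator; result == 42
```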
diff --git a/capsula/_watcher/__init__.py b/capsula/_watcher/__init__.py
new file mode 100644
index 00000000..32a2fa0c
--- /dev/null
+++ b/capsula/_watcher/__init__.py
@@ -0,0 +1,4 @@
+__all__ = ["WatcherBase", "TimeWatcher", "UncaughtExceptionWatcher"]
+from ._base import WatcherBase
+from ._exception import UncaughtExceptionWatcher
+from ._time import TimeWatcher
diff --git a/capsula/watcher/_base.py b/capsula/_watcher/_base.py
similarity index 79%
rename from capsula/watcher/_base.py
rename to capsula/_watcher/_base.py
index 49d9defb..7127eb5b 100644
--- a/capsula/watcher/_base.py
+++ b/capsula/_watcher/_base.py
@@ -6,10 +6,10 @@
if TYPE_CHECKING:
from contextlib import AbstractContextManager
-from capsula.capsule import CapsuleItem
+from capsula._capsule import CapsuleItem
-class Watcher(CapsuleItem):
+class WatcherBase(CapsuleItem):
@abstractmethod
def watch(self) -> AbstractContextManager[None]:
raise NotImplementedError
diff --git a/capsula/watcher/_exception.py b/capsula/_watcher/_exception.py
similarity index 93%
rename from capsula/watcher/_exception.py
rename to capsula/_watcher/_exception.py
index e1dda665..19a13ed0 100644
--- a/capsula/watcher/_exception.py
+++ b/capsula/_watcher/_exception.py
@@ -9,12 +9,12 @@
from capsula.utils import ExceptionInfo
-from ._base import Watcher
+from ._base import WatcherBase
logger = logging.getLogger(__name__)
-class UncaughtExceptionWatcher(Watcher):
+class UncaughtExceptionWatcher(WatcherBase):
def __init__(self, name: str, *, base: type[BaseException] = Exception, reraise: bool = False) -> None:
self.name = name
self.base = base
diff --git a/capsula/watcher/_time.py b/capsula/_watcher/_time.py
similarity index 92%
rename from capsula/watcher/_time.py
rename to capsula/_watcher/_time.py
index 35997ca5..c55e726d 100644
--- a/capsula/watcher/_time.py
+++ b/capsula/_watcher/_time.py
@@ -9,12 +9,12 @@
if TYPE_CHECKING:
from collections.abc import Iterator
-from ._base import Watcher
+from ._base import WatcherBase
logger = logging.getLogger(__name__)
-class TimeWatcher(Watcher):
+class TimeWatcher(WatcherBase):
def __init__(self, name: str) -> None:
self.name = name
self.duration: timedelta | None = None
diff --git a/capsula/capture.py b/capsula/capture.py
index 1544e71e..7618a875 100644
--- a/capsula/capture.py
+++ b/capsula/capture.py
@@ -5,8 +5,8 @@
import logging
import subprocess
-from capsula._context import Context
from capsula.config import CapsulaConfig
+from capsula.context import Context
logger = logging.getLogger(__name__)
diff --git a/capsula/_context.py b/capsula/context.py
similarity index 100%
rename from capsula/_context.py
rename to capsula/context.py
diff --git a/capsula/context/_base.py b/capsula/context/_base.py
deleted file mode 100644
index a5759c5a..00000000
--- a/capsula/context/_base.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from capsula.capsule import CapsuleItem
-
-
-class Context(CapsuleItem):
- pass
diff --git a/capsula/encapsulator.py b/capsula/encapsulator.py
index c5478ac1..6030cb99 100644
--- a/capsula/encapsulator.py
+++ b/capsula/encapsulator.py
@@ -1,24 +1,24 @@
from __future__ import annotations
import queue
+import threading
from collections import OrderedDict
from collections.abc import Hashable
from contextlib import AbstractContextManager
from itertools import chain
from typing import TYPE_CHECKING, Any, Generic, Tuple, TypeVar, Union
-if TYPE_CHECKING:
- from ._backport import TypeAlias
+from capsula.utils import ExceptionInfo
+
+from ._capsule import Capsule
+from ._context import ContextBase
+from ._watcher import WatcherBase
+from .exceptions import CapsulaError
if TYPE_CHECKING:
from types import TracebackType
-from capsula.utils import ExceptionInfo
-
-from .capsule import Capsule
-from .context import Context
-from .exceptions import CapsulaError
-from .watcher import Watcher
+ from ._backport import Self, TypeAlias
_CapsuleItemKey: TypeAlias = Union[str, Tuple[str, ...]]
@@ -28,7 +28,7 @@ def __init__(self, key: _CapsuleItemKey) -> None:
super().__init__(f"Capsule item with key {key} already exists.")
-class ObjectContext(Context):
+class ObjectContext(ContextBase):
def __init__(self, obj: Any) -> None:
self.obj = obj
@@ -37,7 +37,7 @@ def encapsulate(self) -> Any:
_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V", bound=Watcher)
+_V = TypeVar("_V", bound=WatcherBase)
class WatcherGroup(AbstractContextManager, Generic[_K, _V]):
@@ -64,7 +64,7 @@ def __exit__(
suppress_exception = False
while not self.context_manager_stack.empty():
- cm = self.context_manager_stack.get()
+ cm = self.context_manager_stack.get(block=False)
suppress = bool(cm.__exit__(exc_type, exc_value, traceback))
suppress_exception = suppress_exception or suppress
@@ -77,11 +77,38 @@ def __exit__(
class Encapsulator:
+ _thread_local = threading.local()
+
+ @classmethod
+ def _get_context_stack(cls) -> queue.LifoQueue[Self]:
+ if not hasattr(cls._thread_local, "context_stack"):
+ cls._thread_local.context_stack = queue.LifoQueue()
+ return cls._thread_local.context_stack
+
+ @classmethod
+ def get_current(cls) -> Self | None:
+ try:
+ return cls._get_context_stack().queue[-1]
+ except IndexError:
+ return None
+
def __init__(self) -> None:
- self.contexts: OrderedDict[_CapsuleItemKey, Context] = OrderedDict()
- self.watchers: OrderedDict[_CapsuleItemKey, Watcher] = OrderedDict()
+ self.contexts: OrderedDict[_CapsuleItemKey, ContextBase] = OrderedDict()
+ self.watchers: OrderedDict[_CapsuleItemKey, WatcherBase] = OrderedDict()
+
+ def __enter__(self) -> Self:
+ self._get_context_stack().put(self)
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ self._get_context_stack().get(block=False)
- def add_context(self, context: Context, key: _CapsuleItemKey | None = None) -> None:
+ def add_context(self, context: ContextBase, key: _CapsuleItemKey | None = None) -> None:
if key is None:
key = context.default_key()
if key in self.contexts or key in self.watchers:
@@ -91,7 +118,7 @@ def add_context(self, context: Context, key: _CapsuleItemKey | None = None) -> N
def record(self, key: _CapsuleItemKey, record: Any) -> None:
self.add_context(ObjectContext(record), key)
- def add_watcher(self, watcher: Watcher, key: _CapsuleItemKey | None = None) -> None:
+ def add_watcher(self, watcher: WatcherBase, key: _CapsuleItemKey | None = None) -> None:
if key is None:
key = watcher.default_key()
if key in self.contexts or key in self.watchers:
@@ -110,5 +137,5 @@ def encapsulate(self, *, abort_on_error: bool = False) -> Capsule:
fails[key] = ExceptionInfo.from_exception(e)
return Capsule(data, fails)
- def watch(self) -> WatcherGroup[_CapsuleItemKey, Watcher]:
+ def watch(self) -> WatcherGroup[_CapsuleItemKey, WatcherBase]:
return WatcherGroup(self.watchers)
diff --git a/capsula/reporter/__init__.py b/capsula/reporter/__init__.py
deleted file mode 100644
index 178bd6a0..00000000
--- a/capsula/reporter/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-__all__ = ["JsonDumpReporter"]
-from ._json import JsonDumpReporter
diff --git a/capsula/watcher/__init__.py b/capsula/watcher/__init__.py
deleted file mode 100644
index aa9376d1..00000000
--- a/capsula/watcher/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-__all__ = ["Watcher", "TimeWatcher", "UncaughtExceptionWatcher"]
-from ._base import Watcher
-from ._exception import UncaughtExceptionWatcher
-from ._time import TimeWatcher
diff --git a/coverage/badge.svg b/coverage/badge.svg
index ab10c4c0..000194b0 100644
--- a/coverage/badge.svg
+++ b/coverage/badge.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/examples/decorator.py b/examples/decorator.py
new file mode 100644
index 00000000..96f8bb91
--- /dev/null
+++ b/examples/decorator.py
@@ -0,0 +1,62 @@
+import logging
+import random
+from datetime import UTC, datetime
+from pathlib import Path
+
+import orjson
+
+import capsula
+
+logger = logging.getLogger(__name__)
+
+
+@capsula.run(
+ run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"),
+)
+@capsula.context(
+ lambda params: capsula.FileContext(
+ Path(__file__).parents[1] / "pyproject.toml",
+ hash_algorithm="sha256",
+ copy_to=params.run_dir,
+ ),
+ mode="pre",
+)
+@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
+@capsula.reporter(
+ lambda params: capsula.JsonDumpReporter(
+ params.run_dir / f"{params.phase}-run-report.json",
+ option=orjson.OPT_INDENT_2,
+ ),
+ mode="all",
+)
+@capsula.watcher(capsula.TimeWatcher("calculation_time"))
+@capsula.context(
+ lambda params: capsula.FileContext(
+ Path(__file__).parent / "pi.txt",
+ hash_algorithm="sha256",
+ move_to=params.run_dir,
+ ),
+ mode="post",
+)
+def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
+ logger.info(f"Calculating pi with {n_samples} samples.")
+ logger.debug(f"Seed: {seed}")
+ random.seed(seed)
+ xs = (random.random() for _ in range(n_samples)) # noqa: S311
+ ys = (random.random() for _ in range(n_samples)) # noqa: S311
+ inside = sum(x * x + y * y <= 1.0 for x, y in zip(xs, ys))
+
+ capsula.record("inside", inside)
+
+ pi_estimate = (4.0 * inside) / n_samples
+ logger.info(f"Pi estimate: {pi_estimate}")
+ capsula.record("pi_estimate", pi_estimate)
+ # raise CapsulaError("This is a test error.")
+
+ with (Path(__file__).parent / "pi.txt").open("w") as output_file:
+ output_file.write(str(pi_estimate))
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ calculate_pi(n_samples=1_000)
diff --git a/examples/enc_context_manager.py b/examples/enc_context_manager.py
new file mode 100644
index 00000000..c443487f
--- /dev/null
+++ b/examples/enc_context_manager.py
@@ -0,0 +1,51 @@
+import logging
+import random
+from datetime import UTC, datetime
+from pathlib import Path
+
+import orjson
+
+import capsula
+
+logger = logging.getLogger(__name__)
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+def calc_pi(n_samples: int, seed: int) -> float:
+ random.seed(seed)
+ xs = (random.random() for _ in range(n_samples)) # noqa: S311
+ ys = (random.random() for _ in range(n_samples)) # noqa: S311
+ inside = sum(x * x + y * y <= 1.0 for x, y in zip(xs, ys))
+
+ capsula.record("inside", inside)
+
+ pi_estimate = (4.0 * inside) / n_samples
+ logger.info(f"Pi estimate: {pi_estimate}")
+ capsula.record("pi_estimate", pi_estimate)
+
+ return pi_estimate
+
+
+def main(n_samples: int, seed: int) -> None:
+ # Define the run name and create the capsule directory
+ run_name = datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S")
+ capsule_directory = Path(__file__).parents[1] / "vault" / run_name
+
+ with capsula.Encapsulator() as enc:
+ logger.info(f"Calculating pi with {n_samples} samples.")
+ logger.debug(f"Seed: {seed}")
+
+ pi_estimate = calc_pi(n_samples, seed)
+
+ with (Path(__file__).parent / "pi.txt").open("w") as output_file:
+ output_file.write(str(pi_estimate))
+
+ in_run_capsule = enc.encapsulate()
+
+ in_run_reporter = capsula.JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2)
+ in_run_reporter.report(in_run_capsule)
+
+
+if __name__ == "__main__":
+ main(1000, 42)
diff --git a/examples/low_level.py b/examples/low_level.py
index 05d3950e..91cd20c1 100644
--- a/examples/low_level.py
+++ b/examples/low_level.py
@@ -5,18 +5,7 @@
import orjson
-from capsula import Encapsulator
-from capsula.context import (
- CommandContext,
- CpuContext,
- CwdContext,
- EnvVarContext,
- FileContext,
- GitRepositoryContext,
- PlatformContext,
-)
-from capsula.reporter import JsonDumpReporter
-from capsula.watcher import TimeWatcher, UncaughtExceptionWatcher
+import capsula
logger = logging.getLogger(__name__)
@@ -32,10 +21,10 @@
capsule_directory.mkdir(parents=True, exist_ok=True)
# Create an encapsulator
-pre_run_enc = Encapsulator()
+pre_run_enc = capsula.Encapsulator()
# Create a reporter
-pre_run_reporter = JsonDumpReporter(capsule_directory / "pre_run_report.json", option=orjson.OPT_INDENT_2)
+pre_run_reporter = capsula.JsonDumpReporter(capsule_directory / "pre_run_report.json", option=orjson.OPT_INDENT_2)
# slack_reporter = SlackReporter(
# webhook_url="https://hooks.slack.com/services/T01JZQZQZQZ/B01JZQZQZQZ/QQZQZQZQZQZQZQZQZQZQZQZ",
# channel="test",
@@ -45,7 +34,7 @@
# The order of the contexts is important.
pre_run_enc.record("run_name", run_name)
pre_run_enc.add_context(
- GitRepositoryContext(
+ capsula.GitRepositoryContext(
name="capsula",
path=Path(__file__).parents[1],
diff_file=capsule_directory / "capsula.diff",
@@ -53,26 +42,30 @@
),
key=("git", "capsula"),
)
-pre_run_enc.add_context(CpuContext())
-pre_run_enc.add_context(PlatformContext())
-pre_run_enc.add_context(CwdContext())
-pre_run_enc.add_context(EnvVarContext("HOME"), key=("env", "HOME"))
-pre_run_enc.add_context(EnvVarContext("PATH")) # Default key will be used
-pre_run_enc.add_context(CommandContext("poetry check --lock"))
+pre_run_enc.add_context(capsula.CpuContext())
+pre_run_enc.add_context(capsula.PlatformContext())
+pre_run_enc.add_context(capsula.CwdContext())
+pre_run_enc.add_context(capsula.EnvVarContext("HOME"), key=("env", "HOME"))
+pre_run_enc.add_context(capsula.EnvVarContext("PATH")) # Default key will be used
+pre_run_enc.add_context(capsula.CommandContext("poetry check --lock"))
# This will have a side effect
-pre_run_enc.add_context(CommandContext("pip freeze --exclude-editable > requirements.txt"))
+pre_run_enc.add_context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"))
pre_run_enc.add_context(
- FileContext(
+ capsula.FileContext(
Path(__file__).parents[1] / "requirements.txt",
hash_algorithm="sha256",
move_to=capsule_directory,
),
)
pre_run_enc.add_context(
- FileContext(Path(__file__).parents[1] / "pyproject.toml", hash_algorithm="sha256", copy_to=capsule_directory),
+ capsula.FileContext(
+ Path(__file__).parents[1] / "pyproject.toml",
+ hash_algorithm="sha256",
+ copy_to=capsule_directory,
+ ),
)
pre_run_enc.add_context(
- FileContext(Path(__file__).parents[1] / "poetry.lock", hash_algorithm="sha256", copy_to=capsule_directory),
+ capsula.FileContext(Path(__file__).parents[1] / "poetry.lock", hash_algorithm="sha256", copy_to=capsule_directory),
)
pre_run_capsule = pre_run_enc.encapsulate()
@@ -80,15 +73,15 @@
# slack_reporter.report(pre_run_capsule)
# Actual calculation
-in_run_enc = Encapsulator()
-in_run_reporter = JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2)
+in_run_enc = capsula.Encapsulator()
+in_run_reporter = capsula.JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2)
# The order matters. The first watcher will be the innermost one.
# Record the time it takes to run the function.
-in_run_enc.add_watcher(TimeWatcher("calculation_time"))
+in_run_enc.add_watcher(capsula.TimeWatcher("calculation_time"))
# Catch the exception raised by the encapsulated function.
-in_run_enc.add_watcher(UncaughtExceptionWatcher("Exception", base=Exception, reraise=False))
+in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception, reraise=False))
with in_run_enc.watch():
logger.info(f"Calculating pi with {N_SAMPLES} samples.")
@@ -109,7 +102,7 @@
output_file.write(str(pi_estimate))
in_run_enc.add_context(
- FileContext(Path(__file__).parent / "pi.txt", hash_algorithm="sha256", move_to=capsule_directory),
+ capsula.FileContext(Path(__file__).parent / "pi.txt", hash_algorithm="sha256", move_to=capsule_directory),
)
in_run_capsule = in_run_enc.encapsulate()