diff --git a/capsula/__init__.py b/capsula/__init__.py index 7b1dd4f4..400990cf 100644 --- a/capsula/__init__.py +++ b/capsula/__init__.py @@ -8,9 +8,44 @@ "set_capsule_dir", "set_capsule_name", "Encapsulator", + "capsule", + "record", + "Run", + "ContextBase", + "CwdContext", + "EnvVarContext", + "GitRepositoryContext", + "FileContext", + "PlatformContext", + "CpuContext", + "CommandContext", + "JsonDumpReporter", + "ReporterBase", + "WatcherBase", + "TimeWatcher", + "watcher", + "reporter", + "context", + "UncaughtExceptionWatcher", + "run", ] +from ._context import ( + CommandContext, + ContextBase, + CpuContext, + CwdContext, + EnvVarContext, + FileContext, + GitRepositoryContext, + PlatformContext, +) +from ._decorator import capsule, context, reporter, run, watcher from ._monitor import monitor +from ._reporter import JsonDumpReporter, ReporterBase +from ._root import record +from ._run import Run from ._version import __version__ +from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase from .encapsulator import Encapsulator from .exceptions import CapsulaConfigurationError, CapsulaError from .globalvars import get_capsule_dir, get_capsule_name, set_capsule_dir, set_capsule_name diff --git a/capsula/_backport.py b/capsula/_backport.py index 91ad2872..00bb83d1 100644 --- a/capsula/_backport.py +++ b/capsula/_backport.py @@ -1,10 +1,6 @@ from __future__ import annotations -__all__ = [ - "TypeAlias", - "file_digest", - "Self", -] +__all__ = ["TypeAlias", "file_digest", "Self", "ParamSpec"] import hashlib import sys @@ -16,9 +12,9 @@ from typing_extensions import Self if sys.version_info >= (3, 10): - from typing import TypeAlias + from typing import ParamSpec, TypeAlias else: - from typing_extensions import TypeAlias + from typing_extensions import ParamSpec, TypeAlias if sys.version_info >= (3, 11): diff --git a/capsula/capsule.py b/capsula/_capsule.py similarity index 100% rename from capsula/capsule.py rename to capsula/_capsule.py diff --git a/capsula/context/__init__.py b/capsula/_context/__init__.py similarity index 88% rename from capsula/context/__init__.py rename to capsula/_context/__init__.py index 470e934d..23519d45 100644 --- a/capsula/context/__init__.py +++ b/capsula/_context/__init__.py @@ -1,5 +1,5 @@ __all__ = [ - "Context", + "ContextBase", "CwdContext", "EnvVarContext", "GitRepositoryContext", @@ -8,7 +8,7 @@ "CpuContext", "CommandContext", ] -from ._base import Context +from ._base import ContextBase from ._command import CommandContext from ._cpu import CpuContext from ._cwd import CwdContext diff --git a/capsula/_context/_base.py b/capsula/_context/_base.py new file mode 100644 index 00000000..eb5746da --- /dev/null +++ b/capsula/_context/_base.py @@ -0,0 +1,5 @@ +from capsula._capsule import CapsuleItem + + +class ContextBase(CapsuleItem): + pass diff --git a/capsula/context/_command.py b/capsula/_context/_command.py similarity index 93% rename from capsula/context/_command.py rename to capsula/_context/_command.py index bf1a79dd..52e7c004 100644 --- a/capsula/context/_command.py +++ b/capsula/_context/_command.py @@ -7,12 +7,12 @@ if TYPE_CHECKING: from pathlib import Path -from ._base import Context +from ._base import ContextBase logger = logging.getLogger(__name__) -class CommandContext(Context): +class CommandContext(ContextBase): def __init__(self, command: str, cwd: Path | None = None) -> None: self.command = command self.cwd = cwd diff --git a/capsula/context/_cpu.py b/capsula/_context/_cpu.py similarity index 71% rename from 
capsula/context/_cpu.py rename to capsula/_context/_cpu.py index 7ab20029..38fb1a9d 100644 --- a/capsula/context/_cpu.py +++ b/capsula/_context/_cpu.py @@ -1,9 +1,9 @@ from cpuinfo import get_cpu_info -from ._base import Context +from ._base import ContextBase -class CpuContext(Context): +class CpuContext(ContextBase): def encapsulate(self) -> dict: return get_cpu_info() diff --git a/capsula/context/_cwd.py b/capsula/_context/_cwd.py similarity index 70% rename from capsula/context/_cwd.py rename to capsula/_context/_cwd.py index d5458b6f..fa804ea1 100644 --- a/capsula/context/_cwd.py +++ b/capsula/_context/_cwd.py @@ -1,9 +1,9 @@ from pathlib import Path -from ._base import Context +from ._base import ContextBase -class CwdContext(Context): +class CwdContext(ContextBase): def encapsulate(self) -> Path: return Path.cwd() diff --git a/capsula/context/_envvar.py b/capsula/_context/_envvar.py similarity index 80% rename from capsula/context/_envvar.py rename to capsula/_context/_envvar.py index 58dd5da4..96283bec 100644 --- a/capsula/context/_envvar.py +++ b/capsula/_context/_envvar.py @@ -2,10 +2,10 @@ import os -from ._base import Context +from ._base import ContextBase -class EnvVarContext(Context): +class EnvVarContext(ContextBase): def __init__(self, name: str) -> None: self.name = name diff --git a/capsula/context/_file.py b/capsula/_context/_file.py similarity index 83% rename from capsula/context/_file.py rename to capsula/_context/_file.py index 4c1db765..5ce0f005 100644 --- a/capsula/context/_file.py +++ b/capsula/_context/_file.py @@ -7,12 +7,12 @@ from capsula._backport import file_digest -from ._base import Context +from ._base import ContextBase logger = logging.getLogger(__name__) -class FileContext(Context): +class FileContext(ContextBase): def __init__( self, path: Path | str, @@ -32,13 +32,11 @@ def __init__( else: self.copy_to = tuple(Path(p) for p in copy_to) - def normalize_copy_dst_path(p: Path) -> Path: - if p.is_dir(): - return p / self.path.name - else: - return p - - self.copy_to = tuple(normalize_copy_dst_path(p) for p in self.copy_to) + def _normalize_copy_dst_path(self, p: Path) -> Path: + if p.is_dir(): + return p / self.path.name + else: + return p def encapsulate(self) -> dict: if self.hash_algorithm is None: @@ -47,6 +45,8 @@ def encapsulate(self) -> dict: with self.path.open("rb") as f: digest = file_digest(f, self.hash_algorithm).hexdigest() + self.copy_to = tuple(self._normalize_copy_dst_path(p) for p in self.copy_to) + info: dict = { "hash": { "algorithm": self.hash_algorithm, diff --git a/capsula/context/_git.py b/capsula/_context/_git.py similarity index 66% rename from capsula/context/_git.py rename to capsula/_context/_git.py index 4cc3ca92..0c53216d 100644 --- a/capsula/context/_git.py +++ b/capsula/_context/_git.py @@ -1,13 +1,18 @@ from __future__ import annotations +import inspect import logging from pathlib import Path +from typing import TYPE_CHECKING, Callable from git.repo import Repo from capsula.exceptions import CapsulaError -from ._base import Context +from ._base import ContextBase + +if TYPE_CHECKING: + from capsula._decorator import CapsuleParams logger = logging.getLogger(__name__) @@ -18,7 +23,7 @@ def __init__(self, repo: Repo) -> None: super().__init__(f"Repository {repo.working_dir} is dirty") -class GitRepositoryContext(Context): +class GitRepositoryContext(ContextBase): def __init__( self, name: str, @@ -57,3 +62,19 @@ def encapsulate(self) -> dict: def default_key(self) -> tuple[str, str]: return ("git", self.name) + + 
@classmethod + def default(cls) -> Callable[[CapsuleParams], GitRepositoryContext]: + def callback(params: CapsuleParams) -> GitRepositoryContext: + func_file_path = Path(inspect.getfile(params.func)) + repo = Repo(func_file_path.parent, search_parent_directories=True) + repo_name = Path(repo.working_dir).name + return cls( + name=Path(repo.working_dir).name, + path=Path(repo.working_dir), + diff_file=params.run_dir / f"{repo_name}.diff", + search_parent_directories=False, + allow_dirty=True, + ) + + return callback diff --git a/capsula/context/_platform.py b/capsula/_context/_platform.py similarity index 93% rename from capsula/context/_platform.py rename to capsula/_context/_platform.py index c365f5c1..d401fdfd 100644 --- a/capsula/context/_platform.py +++ b/capsula/_context/_platform.py @@ -1,9 +1,9 @@ import platform as pf -from ._base import Context +from ._base import ContextBase -class PlatformContext(Context): +class PlatformContext(ContextBase): def encapsulate(self) -> dict: return { "machine": pf.machine(), diff --git a/capsula/_decorator.py b/capsula/_decorator.py new file mode 100644 index 00000000..bb87b35a --- /dev/null +++ b/capsula/_decorator.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from functools import wraps +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Literal, Tuple, TypeVar, Union + +from capsula._reporter import ReporterBase +from capsula.encapsulator import Encapsulator + +from ._backport import ParamSpec +from ._context import ContextBase +from ._run import CapsuleParams, FuncInfo, Run +from ._watcher import WatcherBase + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ._backport import TypeAlias + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +_ContextInput: TypeAlias = Union[ + ContextBase, + Tuple[ContextBase, Tuple[str, ...]], + Callable[[Path, Callable], Union[ContextBase, Tuple[ContextBase, Tuple[str, ...]]]], +] +_WatcherInput: TypeAlias = Union[ + WatcherBase, + Tuple[WatcherBase, Tuple[str, ...]], + Callable[[Path, Callable], Union[WatcherBase, Tuple[WatcherBase, Tuple[str, ...]]]], +] +_ReporterInput: TypeAlias = Union[ReporterBase, Callable[[Path, Callable], ReporterBase]] + + +def capsule( # noqa: C901 + capsule_directory: Path | str | None = None, + pre_run_contexts: Sequence[_ContextInput] | None = None, + pre_run_reporters: Sequence[_ReporterInput] | None = None, + in_run_watchers: Sequence[_WatcherInput] | None = None, + post_run_contexts: Sequence[_ContextInput] | None = None, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + if capsule_directory is None: + raise NotImplementedError + capsule_directory = Path(capsule_directory) + + assert pre_run_contexts is not None + assert pre_run_reporters is not None + assert in_run_watchers is not None + assert post_run_contexts is not None + + def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]: + pre_run_enc = Encapsulator() + for cxt in pre_run_contexts: + if isinstance(cxt, ContextBase): + pre_run_enc.add_context(cxt) + elif isinstance(cxt, tuple): + pre_run_enc.add_context(cxt[0], key=cxt[1]) + else: + cxt_hydrated = cxt(capsule_directory, func) + if isinstance(cxt_hydrated, ContextBase): + pre_run_enc.add_context(cxt_hydrated) + elif isinstance(cxt_hydrated, tuple): + pre_run_enc.add_context(cxt_hydrated[0], key=cxt_hydrated[1]) + + @wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + capsule_directory.mkdir(parents=True, exist_ok=True) + pre_run_capsule = pre_run_enc.encapsulate() + for reporter in 
pre_run_reporters: + if isinstance(reporter, ReporterBase): + reporter.report(pre_run_capsule) + else: + reporter(capsule_directory, func).report(pre_run_capsule) + + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def watcher( + watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase], +) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: + def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: + func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run + run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run.add_watcher(watcher) + return run + + return decorator + + +def reporter( + reporter: ReporterBase | Callable[[CapsuleParams], ReporterBase], + mode: Literal["pre", "in", "post", "all"], +) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: + def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: + func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run + run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run.add_reporter(reporter, mode=mode) + return run + + return decorator + + +def context( + context: ContextBase | Callable[[CapsuleParams], ContextBase], + mode: Literal["pre", "post", "all"], +) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: + def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: + func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run + run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run.add_context(context, mode=mode) + return run + + return decorator + + +def run( + run_dir: Path | Callable[[FuncInfo], Path], +) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: + def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: + func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run + run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run.set_run_dir(run_dir) + + return run + + return decorator diff --git a/capsula/_reporter/__init__.py b/capsula/_reporter/__init__.py new file mode 100644 index 00000000..fff6b94a --- /dev/null +++ b/capsula/_reporter/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["JsonDumpReporter", "ReporterBase"] +from ._base import ReporterBase +from ._json import JsonDumpReporter diff --git a/capsula/reporter/_base.py b/capsula/_reporter/_base.py similarity index 87% rename from capsula/reporter/_base.py rename to capsula/_reporter/_base.py index fe8c4c84..e1926e9d 100644 --- a/capsula/reporter/_base.py +++ b/capsula/_reporter/_base.py @@ -3,7 +3,7 @@ from capsula.encapsulator import Capsule -class Reporter(ABC): +class ReporterBase(ABC): @abstractmethod def report(self, capsule: Capsule) -> None: raise NotImplementedError diff --git a/capsula/reporter/_json.py b/capsula/_reporter/_json.py similarity index 78% rename from capsula/reporter/_json.py rename to capsula/_reporter/_json.py index 02416962..0661e17c 100644 --- a/capsula/reporter/_json.py +++ b/capsula/_reporter/_json.py @@ -1,8 +1,10 @@ from __future__ import annotations import logging +import traceback from datetime import timedelta from pathlib import Path +from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, Optional import orjson @@ -12,30 +14,37 @@ if TYPE_CHECKING: from capsula.encapsulator import Capsule -from ._base import Reporter +from ._base import ReporterBase logger = logging.getLogger(__name__) def default_preset(obj: Any) -> Any: - # if isinstance(obj, timedelta): - # 
return str(obj) if isinstance(obj, Path): return str(obj) if isinstance(obj, timedelta): return str(obj) + if isinstance(obj, type) and issubclass(obj, BaseException): + return obj.__name__ + if isinstance(obj, Exception): + return str(obj) + if isinstance(obj, TracebackType): + return "".join(traceback.format_tb(obj)) raise TypeError -class JsonDumpReporter(Reporter): +class JsonDumpReporter(ReporterBase): def __init__( self, path: Path | str, *, default: Optional[Callable[[Any], Any]] = None, option: Optional[int] = None, + mkdir: bool = True, ) -> None: self.path = Path(path) + if mkdir: + self.path.parent.mkdir(parents=True, exist_ok=True) if default is None: self.default = default_preset diff --git a/capsula/_root.py b/capsula/_root.py new file mode 100644 index 00000000..bec1a441 --- /dev/null +++ b/capsula/_root.py @@ -0,0 +1,11 @@ +from typing import Any + +from .encapsulator import Encapsulator, _CapsuleItemKey + + +def record(key: _CapsuleItemKey, value: Any) -> None: + enc = Encapsulator.get_current() + if enc is None: + msg = "No active encapsulator found." + raise RuntimeError(msg) + enc.record(key, value) diff --git a/capsula/_run.py b/capsula/_run.py new file mode 100644 index 00000000..bf0c38c8 --- /dev/null +++ b/capsula/_run.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import queue +import threading +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar + +from pydantic import BaseModel + +from capsula._reporter import ReporterBase +from capsula.encapsulator import Encapsulator + +from ._backport import ParamSpec, Self +from ._context import ContextBase +from ._watcher import WatcherBase + +if TYPE_CHECKING: + from types import TracebackType + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class FuncInfo(BaseModel): + func: Callable + args: tuple + kwargs: dict + + +class CapsuleParams(FuncInfo): + run_dir: Path + phase: Literal["pre", "in", "post"] + + +class Run(Generic[_P, _T]): + _thread_local = threading.local() + + @classmethod + def _get_run_stack(cls) -> queue.LifoQueue[Self]: + if not hasattr(cls._thread_local, "run_stack"): + cls._thread_local.run_stack = queue.LifoQueue() + return cls._thread_local.run_stack + + @classmethod + def get_current(cls) -> Self | None: + try: + return cls._get_run_stack().queue[-1] + except IndexError: + return None + + def __init__(self, func: Callable[_P, _T]) -> None: + self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = [] + self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = [] + self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = [] + + self.pre_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = [] + self.in_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = [] + self.post_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = [] + + self.func = func + self.run_dir_generator: Callable[[FuncInfo], Path] | None = None + + def add_context( + self, + context: ContextBase | Callable[[CapsuleParams], ContextBase], + *, + mode: Literal["pre", "post", "all"], + ) -> None: + def context_generator(params: CapsuleParams) -> ContextBase: + if isinstance(context, ContextBase): + return context + else: + return context(params) + + if mode == "pre": + self.pre_run_context_generators.append(context_generator) + elif mode == "post": + self.post_run_context_generators.append(context_generator) + elif mode == "all": + 
self.pre_run_context_generators.append(context_generator) + self.post_run_context_generators.append(context_generator) + else: + msg = f"mode must be one of 'pre', 'post', or 'all', not {mode}." + raise ValueError(msg) + + def add_watcher(self, watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase]) -> None: + def watcher_generator(params: CapsuleParams) -> WatcherBase: + if isinstance(watcher, WatcherBase): + return watcher + else: + return watcher(params) + + self.in_run_watcher_generators.append(watcher_generator) + + def add_reporter( + self, + reporter: ReporterBase | Callable[[CapsuleParams], ReporterBase], + *, + mode: Literal["pre", "in", "post", "all"], + ) -> None: + def reporter_generator(params: CapsuleParams) -> ReporterBase: + if isinstance(reporter, ReporterBase): + return reporter + else: + return reporter(params) + + if mode == "pre": + self.pre_run_reporter_generators.append(reporter_generator) + elif mode == "in": + self.in_run_reporter_generators.append(reporter_generator) + elif mode == "post": + self.post_run_reporter_generators.append(reporter_generator) + elif mode == "all": + self.pre_run_reporter_generators.append(reporter_generator) + self.in_run_reporter_generators.append(reporter_generator) + self.post_run_reporter_generators.append(reporter_generator) + else: + msg = f"mode must be one of 'pre', 'in', 'post', or 'all', not {mode}." + raise ValueError(msg) + + def set_run_dir(self, run_dir: Path | Callable[[FuncInfo], Path]) -> None: + def run_dir_generator(params: FuncInfo) -> Path: + if isinstance(run_dir, Path): + return run_dir + else: + return run_dir(params) + + self.run_dir_generator = run_dir_generator + + def __enter__(self) -> Self: + self._get_run_stack().put(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._get_run_stack().get(block=False) + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + func_info = FuncInfo(func=self.func, args=args, kwargs=kwargs) + if self.run_dir_generator is None: + msg = "run_dir_generator must be set before calling the function." 
+ raise ValueError(msg) + run_dir = self.run_dir_generator(func_info) + run_dir.mkdir(parents=True, exist_ok=True) + params = CapsuleParams( + func=func_info.func, + args=func_info.args, + kwargs=func_info.kwargs, + run_dir=run_dir, + phase="pre", + ) + + pre_run_enc = Encapsulator() + for context_generator in self.pre_run_context_generators: + context = context_generator(params) + pre_run_enc.add_context(context) + pre_run_capsule = pre_run_enc.encapsulate() + for reporter_generator in self.pre_run_reporter_generators: + reporter = reporter_generator(params) + reporter.report(pre_run_capsule) + + params.phase = "in" + in_run_enc = Encapsulator() + for watcher_generator in self.in_run_watcher_generators: + watcher = watcher_generator(params) + in_run_enc.add_watcher(watcher) + + with self, in_run_enc, in_run_enc.watch(): + result = self.func(*args, **kwargs) + + in_run_capsule = in_run_enc.encapsulate() + for reporter_generator in self.in_run_reporter_generators: + reporter = reporter_generator(params) + reporter.report(in_run_capsule) + + params.phase = "post" + post_run_enc = Encapsulator() + for context_generator in self.post_run_context_generators: + context = context_generator(params) + post_run_enc.add_context(context) + post_run_capsule = post_run_enc.encapsulate() + for reporter_generator in self.post_run_reporter_generators: + reporter = reporter_generator(params) + reporter.report(post_run_capsule) + + return result diff --git a/capsula/_watcher/__init__.py b/capsula/_watcher/__init__.py new file mode 100644 index 00000000..32a2fa0c --- /dev/null +++ b/capsula/_watcher/__init__.py @@ -0,0 +1,4 @@ +__all__ = ["WatcherBase", "TimeWatcher", "UncaughtExceptionWatcher"] +from ._base import WatcherBase +from ._exception import UncaughtExceptionWatcher +from ._time import TimeWatcher diff --git a/capsula/watcher/_base.py b/capsula/_watcher/_base.py similarity index 79% rename from capsula/watcher/_base.py rename to capsula/_watcher/_base.py index 49d9defb..7127eb5b 100644 --- a/capsula/watcher/_base.py +++ b/capsula/_watcher/_base.py @@ -6,10 +6,10 @@ if TYPE_CHECKING: from contextlib import AbstractContextManager -from capsula.capsule import CapsuleItem +from capsula._capsule import CapsuleItem -class Watcher(CapsuleItem): +class WatcherBase(CapsuleItem): @abstractmethod def watch(self) -> AbstractContextManager[None]: raise NotImplementedError diff --git a/capsula/watcher/_exception.py b/capsula/_watcher/_exception.py similarity index 93% rename from capsula/watcher/_exception.py rename to capsula/_watcher/_exception.py index e1dda665..19a13ed0 100644 --- a/capsula/watcher/_exception.py +++ b/capsula/_watcher/_exception.py @@ -9,12 +9,12 @@ from capsula.utils import ExceptionInfo -from ._base import Watcher +from ._base import WatcherBase logger = logging.getLogger(__name__) -class UncaughtExceptionWatcher(Watcher): +class UncaughtExceptionWatcher(WatcherBase): def __init__(self, name: str, *, base: type[BaseException] = Exception, reraise: bool = False) -> None: self.name = name self.base = base diff --git a/capsula/watcher/_time.py b/capsula/_watcher/_time.py similarity index 92% rename from capsula/watcher/_time.py rename to capsula/_watcher/_time.py index 35997ca5..c55e726d 100644 --- a/capsula/watcher/_time.py +++ b/capsula/_watcher/_time.py @@ -9,12 +9,12 @@ if TYPE_CHECKING: from collections.abc import Iterator -from ._base import Watcher +from ._base import WatcherBase logger = logging.getLogger(__name__) -class TimeWatcher(Watcher): +class TimeWatcher(WatcherBase): def 
__init__(self, name: str) -> None: self.name = name self.duration: timedelta | None = None diff --git a/capsula/capture.py b/capsula/capture.py index 1544e71e..7618a875 100644 --- a/capsula/capture.py +++ b/capsula/capture.py @@ -5,8 +5,8 @@ import logging import subprocess -from capsula._context import Context from capsula.config import CapsulaConfig +from capsula.context import Context logger = logging.getLogger(__name__) diff --git a/capsula/_context.py b/capsula/context.py similarity index 100% rename from capsula/_context.py rename to capsula/context.py diff --git a/capsula/context/_base.py b/capsula/context/_base.py deleted file mode 100644 index a5759c5a..00000000 --- a/capsula/context/_base.py +++ /dev/null @@ -1,5 +0,0 @@ -from capsula.capsule import CapsuleItem - - -class Context(CapsuleItem): - pass diff --git a/capsula/encapsulator.py b/capsula/encapsulator.py index c5478ac1..6030cb99 100644 --- a/capsula/encapsulator.py +++ b/capsula/encapsulator.py @@ -1,24 +1,24 @@ from __future__ import annotations import queue +import threading from collections import OrderedDict from collections.abc import Hashable from contextlib import AbstractContextManager from itertools import chain from typing import TYPE_CHECKING, Any, Generic, Tuple, TypeVar, Union -if TYPE_CHECKING: - from ._backport import TypeAlias +from capsula.utils import ExceptionInfo + +from ._capsule import Capsule +from ._context import ContextBase +from ._watcher import WatcherBase +from .exceptions import CapsulaError if TYPE_CHECKING: from types import TracebackType -from capsula.utils import ExceptionInfo - -from .capsule import Capsule -from .context import Context -from .exceptions import CapsulaError -from .watcher import Watcher + from ._backport import Self, TypeAlias _CapsuleItemKey: TypeAlias = Union[str, Tuple[str, ...]] @@ -28,7 +28,7 @@ def __init__(self, key: _CapsuleItemKey) -> None: super().__init__(f"Capsule item with key {key} already exists.") -class ObjectContext(Context): +class ObjectContext(ContextBase): def __init__(self, obj: Any) -> None: self.obj = obj @@ -37,7 +37,7 @@ def encapsulate(self) -> Any: _K = TypeVar("_K", bound=Hashable) -_V = TypeVar("_V", bound=Watcher) +_V = TypeVar("_V", bound=WatcherBase) class WatcherGroup(AbstractContextManager, Generic[_K, _V]): @@ -64,7 +64,7 @@ def __exit__( suppress_exception = False while not self.context_manager_stack.empty(): - cm = self.context_manager_stack.get() + cm = self.context_manager_stack.get(block=False) suppress = bool(cm.__exit__(exc_type, exc_value, traceback)) suppress_exception = suppress_exception or suppress @@ -77,11 +77,38 @@ def __exit__( class Encapsulator: + _thread_local = threading.local() + + @classmethod + def _get_context_stack(cls) -> queue.LifoQueue[Self]: + if not hasattr(cls._thread_local, "context_stack"): + cls._thread_local.context_stack = queue.LifoQueue() + return cls._thread_local.context_stack + + @classmethod + def get_current(cls) -> Self | None: + try: + return cls._get_context_stack().queue[-1] + except IndexError: + return None + def __init__(self) -> None: - self.contexts: OrderedDict[_CapsuleItemKey, Context] = OrderedDict() - self.watchers: OrderedDict[_CapsuleItemKey, Watcher] = OrderedDict() + self.contexts: OrderedDict[_CapsuleItemKey, ContextBase] = OrderedDict() + self.watchers: OrderedDict[_CapsuleItemKey, WatcherBase] = OrderedDict() + + def __enter__(self) -> Self: + self._get_context_stack().put(self) + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: 
BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._get_context_stack().get(block=False) - def add_context(self, context: Context, key: _CapsuleItemKey | None = None) -> None: + def add_context(self, context: ContextBase, key: _CapsuleItemKey | None = None) -> None: if key is None: key = context.default_key() if key in self.contexts or key in self.watchers: @@ -91,7 +118,7 @@ def add_context(self, context: Context, key: _CapsuleItemKey | None = None) -> N def record(self, key: _CapsuleItemKey, record: Any) -> None: self.add_context(ObjectContext(record), key) - def add_watcher(self, watcher: Watcher, key: _CapsuleItemKey | None = None) -> None: + def add_watcher(self, watcher: WatcherBase, key: _CapsuleItemKey | None = None) -> None: if key is None: key = watcher.default_key() if key in self.contexts or key in self.watchers: @@ -110,5 +137,5 @@ def encapsulate(self, *, abort_on_error: bool = False) -> Capsule: fails[key] = ExceptionInfo.from_exception(e) return Capsule(data, fails) - def watch(self) -> WatcherGroup[_CapsuleItemKey, Watcher]: + def watch(self) -> WatcherGroup[_CapsuleItemKey, WatcherBase]: return WatcherGroup(self.watchers) diff --git a/capsula/reporter/__init__.py b/capsula/reporter/__init__.py deleted file mode 100644 index 178bd6a0..00000000 --- a/capsula/reporter/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -__all__ = ["JsonDumpReporter"] -from ._json import JsonDumpReporter diff --git a/capsula/watcher/__init__.py b/capsula/watcher/__init__.py deleted file mode 100644 index aa9376d1..00000000 --- a/capsula/watcher/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -__all__ = ["Watcher", "TimeWatcher", "UncaughtExceptionWatcher"] -from ._base import Watcher -from ._exception import UncaughtExceptionWatcher -from ._time import TimeWatcher diff --git a/coverage/badge.svg b/coverage/badge.svg index ab10c4c0..000194b0 100644 --- a/coverage/badge.svg +++ b/coverage/badge.svg @@ -1 +1 @@ -coverage: 69.85%coverage69.85% \ No newline at end of file +coverage: 62.65%coverage62.65% \ No newline at end of file diff --git a/examples/decorator.py b/examples/decorator.py new file mode 100644 index 00000000..96f8bb91 --- /dev/null +++ b/examples/decorator.py @@ -0,0 +1,62 @@ +import logging +import random +from datetime import UTC, datetime +from pathlib import Path + +import orjson + +import capsula + +logger = logging.getLogger(__name__) + + +@capsula.run( + run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"), +) +@capsula.context( + lambda params: capsula.FileContext( + Path(__file__).parents[1] / "pyproject.toml", + hash_algorithm="sha256", + copy_to=params.run_dir, + ), + mode="pre", +) +@capsula.context(capsula.GitRepositoryContext.default(), mode="pre") +@capsula.reporter( + lambda params: capsula.JsonDumpReporter( + params.run_dir / f"{params.phase}-run-report.json", + option=orjson.OPT_INDENT_2, + ), + mode="all", +) +@capsula.watcher(capsula.TimeWatcher("calculation_time")) +@capsula.context( + lambda params: capsula.FileContext( + Path(__file__).parent / "pi.txt", + hash_algorithm="sha256", + move_to=params.run_dir, + ), + mode="post", +) +def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None: + logger.info(f"Calculating pi with {n_samples} samples.") + logger.debug(f"Seed: {seed}") + random.seed(seed) + xs = (random.random() for _ in range(n_samples)) # noqa: S311 + ys = (random.random() for _ in range(n_samples)) # noqa: S311 + inside = sum(x * x + y * y <= 1.0 for x, y in zip(xs, 
ys)) + + capsula.record("inside", inside) + + pi_estimate = (4.0 * inside) / n_samples + logger.info(f"Pi estimate: {pi_estimate}") + capsula.record("pi_estimate", pi_estimate) + # raise CapsulaError("This is a test error.") + + with (Path(__file__).parent / "pi.txt").open("w") as output_file: + output_file.write(str(pi_estimate)) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + calculate_pi(n_samples=1_000) diff --git a/examples/enc_context_manager.py b/examples/enc_context_manager.py new file mode 100644 index 00000000..c443487f --- /dev/null +++ b/examples/enc_context_manager.py @@ -0,0 +1,51 @@ +import logging +import random +from datetime import UTC, datetime +from pathlib import Path + +import orjson + +import capsula + +logger = logging.getLogger(__name__) + +logging.basicConfig(level=logging.DEBUG) + + +def calc_pi(n_samples: int, seed: int) -> float: + random.seed(seed) + xs = (random.random() for _ in range(n_samples)) # noqa: S311 + ys = (random.random() for _ in range(n_samples)) # noqa: S311 + inside = sum(x * x + y * y <= 1.0 for x, y in zip(xs, ys)) + + capsula.record("inside", inside) + + pi_estimate = (4.0 * inside) / n_samples + logger.info(f"Pi estimate: {pi_estimate}") + capsula.record("pi_estimate", pi_estimate) + + return pi_estimate + + +def main(n_samples: int, seed: int) -> None: + # Define the run name and create the capsule directory + run_name = datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S") + capsule_directory = Path(__file__).parents[1] / "vault" / run_name + + with capsula.Encapsulator() as enc: + logger.info(f"Calculating pi with {n_samples} samples.") + logger.debug(f"Seed: {seed}") + + pi_estimate = calc_pi(n_samples, seed) + + with (Path(__file__).parent / "pi.txt").open("w") as output_file: + output_file.write(str(pi_estimate)) + + in_run_capsule = enc.encapsulate() + + in_run_reporter = capsula.JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2) + in_run_reporter.report(in_run_capsule) + + +if __name__ == "__main__": + main(1000, 42) diff --git a/examples/low_level.py b/examples/low_level.py index 05d3950e..91cd20c1 100644 --- a/examples/low_level.py +++ b/examples/low_level.py @@ -5,18 +5,7 @@ import orjson -from capsula import Encapsulator -from capsula.context import ( - CommandContext, - CpuContext, - CwdContext, - EnvVarContext, - FileContext, - GitRepositoryContext, - PlatformContext, -) -from capsula.reporter import JsonDumpReporter -from capsula.watcher import TimeWatcher, UncaughtExceptionWatcher +import capsula logger = logging.getLogger(__name__) @@ -32,10 +21,10 @@ capsule_directory.mkdir(parents=True, exist_ok=True) # Create an encapsulator -pre_run_enc = Encapsulator() +pre_run_enc = capsula.Encapsulator() # Create a reporter -pre_run_reporter = JsonDumpReporter(capsule_directory / "pre_run_report.json", option=orjson.OPT_INDENT_2) +pre_run_reporter = capsula.JsonDumpReporter(capsule_directory / "pre_run_report.json", option=orjson.OPT_INDENT_2) # slack_reporter = SlackReporter( # webhook_url="https://hooks.slack.com/services/T01JZQZQZQZ/B01JZQZQZQZ/QQZQZQZQZQZQZQZQZQZQZQZ", # channel="test", @@ -45,7 +34,7 @@ # The order of the contexts is important. 
pre_run_enc.record("run_name", run_name) pre_run_enc.add_context( - GitRepositoryContext( + capsula.GitRepositoryContext( name="capsula", path=Path(__file__).parents[1], diff_file=capsule_directory / "capsula.diff", @@ -53,26 +42,30 @@ ), key=("git", "capsula"), ) -pre_run_enc.add_context(CpuContext()) -pre_run_enc.add_context(PlatformContext()) -pre_run_enc.add_context(CwdContext()) -pre_run_enc.add_context(EnvVarContext("HOME"), key=("env", "HOME")) -pre_run_enc.add_context(EnvVarContext("PATH")) # Default key will be used -pre_run_enc.add_context(CommandContext("poetry check --lock")) +pre_run_enc.add_context(capsula.CpuContext()) +pre_run_enc.add_context(capsula.PlatformContext()) +pre_run_enc.add_context(capsula.CwdContext()) +pre_run_enc.add_context(capsula.EnvVarContext("HOME"), key=("env", "HOME")) +pre_run_enc.add_context(capsula.EnvVarContext("PATH")) # Default key will be used +pre_run_enc.add_context(capsula.CommandContext("poetry check --lock")) # This will have a side effect -pre_run_enc.add_context(CommandContext("pip freeze --exclude-editable > requirements.txt")) +pre_run_enc.add_context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt")) pre_run_enc.add_context( - FileContext( + capsula.FileContext( Path(__file__).parents[1] / "requirements.txt", hash_algorithm="sha256", move_to=capsule_directory, ), ) pre_run_enc.add_context( - FileContext(Path(__file__).parents[1] / "pyproject.toml", hash_algorithm="sha256", copy_to=capsule_directory), + capsula.FileContext( + Path(__file__).parents[1] / "pyproject.toml", + hash_algorithm="sha256", + copy_to=capsule_directory, + ), ) pre_run_enc.add_context( - FileContext(Path(__file__).parents[1] / "poetry.lock", hash_algorithm="sha256", copy_to=capsule_directory), + capsula.FileContext(Path(__file__).parents[1] / "poetry.lock", hash_algorithm="sha256", copy_to=capsule_directory), ) pre_run_capsule = pre_run_enc.encapsulate() @@ -80,15 +73,15 @@ # slack_reporter.report(pre_run_capsule) # Actual calculation -in_run_enc = Encapsulator() -in_run_reporter = JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2) +in_run_enc = capsula.Encapsulator() +in_run_reporter = capsula.JsonDumpReporter(capsule_directory / "in_run_report.json", option=orjson.OPT_INDENT_2) # The order matters. The first watcher will be the innermost one. # Record the time it takes to run the function. -in_run_enc.add_watcher(TimeWatcher("calculation_time")) +in_run_enc.add_watcher(capsula.TimeWatcher("calculation_time")) # Catch the exception raised by the encapsulated function. -in_run_enc.add_watcher(UncaughtExceptionWatcher("Exception", base=Exception, reraise=False)) +in_run_enc.add_watcher(capsula.UncaughtExceptionWatcher("Exception", base=Exception, reraise=False)) with in_run_enc.watch(): logger.info(f"Calculating pi with {N_SAMPLES} samples.") @@ -109,7 +102,7 @@ output_file.write(str(pi_estimate)) in_run_enc.add_context( - FileContext(Path(__file__).parent / "pi.txt", hash_algorithm="sha256", move_to=capsule_directory), + capsula.FileContext(Path(__file__).parent / "pi.txt", hash_algorithm="sha256", move_to=capsule_directory), ) in_run_capsule = in_run_enc.encapsulate()
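
Usage note: every decorator added in _decorator.py funnels into the Run object introduced in _run.py, so the same pipeline can also be configured imperatively. A minimal sketch built only from the API shown in this diff; the calc function and the vault/manual_run directory are illustrative, not part of the change:

from pathlib import Path

import orjson

import capsula


def calc(n: int) -> int:
    return n * n


# Build the Run by hand instead of stacking decorators; calling it drives the
# pre-run, in-run, and post-run phases exactly as the decorated form does.
manual_run = capsula.Run(calc)
manual_run.set_run_dir(Path("vault") / "manual_run")  # illustrative location
manual_run.add_context(capsula.CwdContext(), mode="pre")
manual_run.add_watcher(capsula.TimeWatcher("calc_time"))
manual_run.add_reporter(
    lambda params: capsula.JsonDumpReporter(
        params.run_dir / f"{params.phase}-run-report.json",
        option=orjson.OPT_INDENT_2,
    ),
    mode="all",
)

result = manual_run(12)  # calc(12) == 144; reports land in vault/manual_run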
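
Contexts, watchers, and reporters may also be passed as factories that receive the resolved CapsuleParams (func, args, kwargs, run_dir, phase), which is how GitRepositoryContext.default() derives its per-run diff path. A sketch of a user-defined factory under the same assumption; output.csv and vault/factory_demo are made-up names for the example:

from pathlib import Path

import orjson

import capsula


# The factory is called at run time with CapsuleParams, so destinations that
# depend on the run directory can be computed lazily.
def output_file_context(params):
    return capsula.FileContext(
        Path(__file__).parent / "output.csv",
        hash_algorithm="sha256",
        copy_to=params.run_dir,
    )


@capsula.run(run_dir=lambda _: Path("vault") / "factory_demo")  # illustrative
@capsula.reporter(
    lambda params: capsula.JsonDumpReporter(
        params.run_dir / f"{params.phase}-run-report.json",
        option=orjson.OPT_INDENT_2,
    ),
    mode="post",
)
@capsula.context(output_file_context, mode="post")
def produce_output() -> None:
    (Path(__file__).parent / "output.csv").write_text("a,b\n1,2\n")


if __name__ == "__main__":
    produce_output()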
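
capsula.record() (new in _root.py) resolves the innermost Encapsulator on the current thread, so it works both inside a `with capsula.Encapsulator():` block and inside a function executed through Run; with nothing active it raises RuntimeError. A minimal sketch:

import capsula

with capsula.Encapsulator() as enc:
    capsula.record("n_samples", 1_000)  # routed to `enc` via the thread-local stack
    capsule = enc.encapsulate()         # the recorded value is included in this capsule

# With no active Encapsulator (and no Run executing), record() raises:
try:
    capsula.record("orphan", 1)
except RuntimeError as e:
    print(e)  # "No active encapsulator found."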