diff --git a/capsula/__init__.py b/capsula/__init__.py index b404ae35..a04f2ce3 100644 --- a/capsula/__init__.py +++ b/capsula/__init__.py @@ -1,6 +1,7 @@ __all__ = [ "CapsulaConfigurationError", "CapsulaError", + "Capsule", "CommandContext", "ContextBase", "CpuContext", @@ -21,6 +22,7 @@ "get_capsule_dir", "get_capsule_name", "monitor", + "pass_pre_run_capsule", "record", "reporter", "run", @@ -28,6 +30,7 @@ "set_capsule_name", "watcher", ] +from ._capsule import Capsule from ._context import ( CommandContext, ContextBase, @@ -38,7 +41,7 @@ GitRepositoryContext, PlatformContext, ) -from ._decorator import context, reporter, run, watcher +from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher from ._reporter import JsonDumpReporter, ReporterBase from ._root import record from ._run import Run diff --git a/capsula/_backport.py b/capsula/_backport.py index 913da585..a7c1f256 100644 --- a/capsula/_backport.py +++ b/capsula/_backport.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["ParamSpec", "Self", "TypeAlias", "file_digest"] +__all__ = ["Concatenate", "ParamSpec", "Self", "TypeAlias", "file_digest"] import hashlib import sys @@ -12,9 +12,9 @@ from typing_extensions import Self if sys.version_info >= (3, 10): - from typing import ParamSpec, TypeAlias + from typing import Concatenate, ParamSpec, TypeAlias else: - from typing_extensions import ParamSpec, TypeAlias + from typing_extensions import Concatenate, ParamSpec, TypeAlias if sys.version_info >= (3, 11): diff --git a/capsula/_decorator.py b/capsula/_decorator.py index 0479fab0..0583e5c3 100644 --- a/capsula/_decorator.py +++ b/capsula/_decorator.py @@ -2,15 +2,15 @@ from typing import TYPE_CHECKING, Callable, Literal, TypeVar -from ._backport import ParamSpec +from ._backport import Concatenate, ParamSpec from ._run import CapsuleParams, FuncInfo, Run if TYPE_CHECKING: from pathlib import Path - from capsula._reporter import ReporterBase - + from ._capsule import Capsule from ._context import ContextBase + from ._reporter import ReporterBase from ._watcher import WatcherBase _P = ParamSpec("_P") @@ -21,8 +21,7 @@ def watcher( watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase], ) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: - func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run - run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run) run.add_watcher(watcher) return run @@ -34,8 +33,7 @@ def reporter( mode: Literal["pre", "in", "post", "all"], ) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: - func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run - run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run) run.add_reporter(reporter, mode=mode) return run @@ -47,8 +45,7 @@ def context( mode: Literal["pre", "post", "all"], ) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: - func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run - run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run) run.add_context(context, mode=mode) return run @@ -59,10 +56,13 @@ def run( run_dir: Path | Callable[[FuncInfo], Path], ) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: - func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run - run = func_or_run if isinstance(func_or_run, Run) else Run(func) + run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run) run.set_run_dir(run_dir) return run return decorator + + +def pass_pre_run_capsule(func: Callable[Concatenate[Capsule, _P], _T]) -> Run[_P, _T]: + return Run(func, pass_pre_run_capsule=True) diff --git a/capsula/_run.py b/capsula/_run.py index bf0c38c8..0942c349 100644 --- a/capsula/_run.py +++ b/capsula/_run.py @@ -3,20 +3,22 @@ import queue import threading from pathlib import Path -from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar, overload from pydantic import BaseModel from capsula._reporter import ReporterBase from capsula.encapsulator import Encapsulator -from ._backport import ParamSpec, Self +from ._backport import Concatenate, ParamSpec, Self from ._context import ContextBase from ._watcher import WatcherBase if TYPE_CHECKING: from types import TracebackType + from ._capsule import Capsule + _P = ParamSpec("_P") _T = TypeVar("_T") @@ -48,7 +50,15 @@ def get_current(cls) -> Self | None: except IndexError: return None - def __init__(self, func: Callable[_P, _T]) -> None: + @overload + def __init__(self, func: Callable[_P, _T], *, pass_pre_run_capsule: Literal[False] = False) -> None: + ... + + @overload + def __init__(self, func: Callable[Concatenate[Capsule, _P], _T], *, pass_pre_run_capsule: Literal[True]) -> None: + ... + + def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None: self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = [] self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = [] self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = [] @@ -57,7 +67,9 @@ def __init__(self, func: Callable[_P, _T]) -> None: self.in_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = [] self.post_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = [] - self.func = func + self.pass_pre_run_capsule: bool = pass_pre_run_capsule + self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func + self.run_dir_generator: Callable[[FuncInfo], Path] | None = None def add_context( @@ -170,7 +182,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: in_run_enc.add_watcher(watcher) with self, in_run_enc, in_run_enc.watch(): - result = self.func(*args, **kwargs) + if self.pass_pre_run_capsule: + result = self.func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type] + else: + result = self.func(*args, **kwargs) in_run_capsule = in_run_enc.encapsulate() for reporter_generator in self.in_run_reporter_generators: diff --git a/coverage/badge.svg b/coverage/badge.svg index 8f2df76d..806459b4 100644 --- a/coverage/badge.svg +++ b/coverage/badge.svg @@ -1 +1 @@ -coverage: 46.89%coverage46.89% \ No newline at end of file +coverage: 47.24%coverage47.24% \ No newline at end of file diff --git a/examples/decorator.py b/examples/decorator.py index 96f8bb91..992f2300 100644 --- a/examples/decorator.py +++ b/examples/decorator.py @@ -38,7 +38,8 @@ ), mode="post", ) -def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None: +@capsula.pass_pre_run_capsule +def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None: logger.info(f"Calculating pi with {n_samples} samples.") logger.debug(f"Seed: {seed}") random.seed(seed) @@ -52,9 +53,10 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None: logger.info(f"Pi estimate: {pi_estimate}") capsula.record("pi_estimate", pi_estimate) # raise CapsulaError("This is a test error.") + logger.info(pre_run_capsule.data) with (Path(__file__).parent / "pi.txt").open("w") as output_file: - output_file.write(str(pi_estimate)) + output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}") if __name__ == "__main__":