diff --git a/capsula/__init__.py b/capsula/__init__.py
index b404ae35..a04f2ce3 100644
--- a/capsula/__init__.py
+++ b/capsula/__init__.py
@@ -1,6 +1,7 @@
__all__ = [
"CapsulaConfigurationError",
"CapsulaError",
+ "Capsule",
"CommandContext",
"ContextBase",
"CpuContext",
@@ -21,6 +22,7 @@
"get_capsule_dir",
"get_capsule_name",
"monitor",
+ "pass_pre_run_capsule",
"record",
"reporter",
"run",
@@ -28,6 +30,7 @@
"set_capsule_name",
"watcher",
]
+from ._capsule import Capsule
from ._context import (
CommandContext,
ContextBase,
@@ -38,7 +41,7 @@
GitRepositoryContext,
PlatformContext,
)
-from ._decorator import context, reporter, run, watcher
+from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher
from ._reporter import JsonDumpReporter, ReporterBase
from ._root import record
from ._run import Run
diff --git a/capsula/_backport.py b/capsula/_backport.py
index 913da585..a7c1f256 100644
--- a/capsula/_backport.py
+++ b/capsula/_backport.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-__all__ = ["ParamSpec", "Self", "TypeAlias", "file_digest"]
+__all__ = ["Concatenate", "ParamSpec", "Self", "TypeAlias", "file_digest"]
import hashlib
import sys
@@ -12,9 +12,9 @@
from typing_extensions import Self
if sys.version_info >= (3, 10):
- from typing import ParamSpec, TypeAlias
+ from typing import Concatenate, ParamSpec, TypeAlias
else:
- from typing_extensions import ParamSpec, TypeAlias
+ from typing_extensions import Concatenate, ParamSpec, TypeAlias
if sys.version_info >= (3, 11):
diff --git a/capsula/_decorator.py b/capsula/_decorator.py
index 0479fab0..0583e5c3 100644
--- a/capsula/_decorator.py
+++ b/capsula/_decorator.py
@@ -2,15 +2,15 @@
from typing import TYPE_CHECKING, Callable, Literal, TypeVar
-from ._backport import ParamSpec
+from ._backport import Concatenate, ParamSpec
from ._run import CapsuleParams, FuncInfo, Run
if TYPE_CHECKING:
from pathlib import Path
- from capsula._reporter import ReporterBase
-
+ from ._capsule import Capsule
from ._context import ContextBase
+ from ._reporter import ReporterBase
from ._watcher import WatcherBase
_P = ParamSpec("_P")
@@ -21,8 +21,7 @@ def watcher(
watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
- func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
- run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_watcher(watcher)
return run
@@ -34,8 +33,7 @@ def reporter(
mode: Literal["pre", "in", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
- func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
- run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_reporter(reporter, mode=mode)
return run
@@ -47,8 +45,7 @@ def context(
mode: Literal["pre", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
- func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
- run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_context(context, mode=mode)
return run
@@ -59,10 +56,13 @@ def run(
run_dir: Path | Callable[[FuncInfo], Path],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
- func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
- run = func_or_run if isinstance(func_or_run, Run) else Run(func)
+ run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.set_run_dir(run_dir)
return run
return decorator
+
+
+def pass_pre_run_capsule(func: Callable[Concatenate[Capsule, _P], _T]) -> Run[_P, _T]:
+ return Run(func, pass_pre_run_capsule=True)
diff --git a/capsula/_run.py b/capsula/_run.py
index bf0c38c8..0942c349 100644
--- a/capsula/_run.py
+++ b/capsula/_run.py
@@ -3,20 +3,22 @@
import queue
import threading
from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar, overload
from pydantic import BaseModel
from capsula._reporter import ReporterBase
from capsula.encapsulator import Encapsulator
-from ._backport import ParamSpec, Self
+from ._backport import Concatenate, ParamSpec, Self
from ._context import ContextBase
from ._watcher import WatcherBase
if TYPE_CHECKING:
from types import TracebackType
+ from ._capsule import Capsule
+
_P = ParamSpec("_P")
_T = TypeVar("_T")
@@ -48,7 +50,15 @@ def get_current(cls) -> Self | None:
except IndexError:
return None
- def __init__(self, func: Callable[_P, _T]) -> None:
+ @overload
+ def __init__(self, func: Callable[_P, _T], *, pass_pre_run_capsule: Literal[False] = False) -> None:
+ ...
+
+ @overload
+ def __init__(self, func: Callable[Concatenate[Capsule, _P], _T], *, pass_pre_run_capsule: Literal[True]) -> None:
+ ...
+
+ def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None:
self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = []
self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
@@ -57,7 +67,9 @@ def __init__(self, func: Callable[_P, _T]) -> None:
self.in_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
self.post_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
- self.func = func
+ self.pass_pre_run_capsule: bool = pass_pre_run_capsule
+ self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func
+
self.run_dir_generator: Callable[[FuncInfo], Path] | None = None
def add_context(
@@ -170,7 +182,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
in_run_enc.add_watcher(watcher)
with self, in_run_enc, in_run_enc.watch():
- result = self.func(*args, **kwargs)
+ if self.pass_pre_run_capsule:
+ result = self.func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
+ else:
+ result = self.func(*args, **kwargs)
in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self.in_run_reporter_generators:
diff --git a/coverage/badge.svg b/coverage/badge.svg
index 8f2df76d..806459b4 100644
--- a/coverage/badge.svg
+++ b/coverage/badge.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/examples/decorator.py b/examples/decorator.py
index 96f8bb91..992f2300 100644
--- a/examples/decorator.py
+++ b/examples/decorator.py
@@ -38,7 +38,8 @@
),
mode="post",
)
-def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
+@capsula.pass_pre_run_capsule
+def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Calculating pi with {n_samples} samples.")
logger.debug(f"Seed: {seed}")
random.seed(seed)
@@ -52,9 +53,10 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Pi estimate: {pi_estimate}")
capsula.record("pi_estimate", pi_estimate)
# raise CapsulaError("This is a test error.")
+ logger.info(pre_run_capsule.data)
with (Path(__file__).parent / "pi.txt").open("w") as output_file:
- output_file.write(str(pi_estimate))
+ output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")
if __name__ == "__main__":