diff --git a/README.md b/README.md
index 6d1360c1..55621e62 100644
--- a/README.md
+++ b/README.md
@@ -25,45 +25,30 @@ See the following Python script:
```python
import logging
import random
-from datetime import UTC, datetime
from pathlib import Path
-import orjson
-
import capsula
logger = logging.getLogger(__name__)
-@capsula.run(
- run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"),
-)
-@capsula.context(
- lambda params: capsula.FileContext(
- Path(__file__).parents[1] / "pyproject.toml",
- hash_algorithm="sha256",
- copy_to=params.run_dir,
- ),
- mode="pre",
-)
-@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
-@capsula.reporter(
- lambda params: capsula.JsonDumpReporter(
- params.run_dir / f"{params.phase}-run-report.json",
- option=orjson.OPT_INDENT_2,
- ),
- mode="all",
-)
+@capsula.run()
+@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all")
+@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post")
+@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception"))
@capsula.watcher(capsula.TimeWatcher("calculation_time"))
-@capsula.context(
- lambda params: capsula.FileContext(
- Path(__file__).parent / "pi.txt",
- hash_algorithm="sha256",
- move_to=params.run_dir,
- ),
- mode="post",
-)
-def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre")
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre")
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre")
+@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
+@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre")
+@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre")
+@capsula.context(capsula.EnvVarContext("HOME"), mode="pre")
+@capsula.context(capsula.EnvVarContext("PATH"), mode="pre")
+@capsula.context(capsula.CwdContext(), mode="pre")
+@capsula.context(capsula.CpuContext(), mode="pre")
+@capsula.pass_pre_run_capsule
+def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Calculating pi with {n_samples} samples.")
logger.debug(f"Seed: {seed}")
random.seed(seed)
@@ -76,13 +61,14 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
pi_estimate = (4.0 * inside) / n_samples
logger.info(f"Pi estimate: {pi_estimate}")
capsula.record("pi_estimate", pi_estimate)
+ logger.info(pre_run_capsule.data)
+ logger.info(capsula.current_run_name())
with (Path(__file__).parent / "pi.txt").open("w") as output_file:
- output_file.write(str(pi_estimate))
+ output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")
if __name__ == "__main__":
- logging.basicConfig(level=logging.DEBUG)
calculate_pi(n_samples=1_000)
```
diff --git a/capsula/__init__.py b/capsula/__init__.py
index a04f2ce3..4f35fc90 100644
--- a/capsula/__init__.py
+++ b/capsula/__init__.py
@@ -19,6 +19,7 @@
"WatcherBase",
"__version__",
"context",
+ "current_run_name",
"get_capsule_dir",
"get_capsule_name",
"monitor",
@@ -43,7 +44,7 @@
)
from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher
from ._reporter import JsonDumpReporter, ReporterBase
-from ._root import record
+from ._root import current_run_name, record
from ._run import Run
from ._version import __version__
from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase
diff --git a/capsula/_capsule.py b/capsula/_capsule.py
index 0a471290..092ef68a 100644
--- a/capsula/_capsule.py
+++ b/capsula/_capsule.py
@@ -1,14 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Tuple, Union
if TYPE_CHECKING:
from collections.abc import Mapping
+ from capsula._decorator import CapsuleParams
from capsula.utils import ExceptionInfo
- from ._backport import TypeAlias
+ from ._backport import Self, TypeAlias
_ContextKey: TypeAlias = Union[str, Tuple[str, ...]]
@@ -32,3 +33,10 @@ def encapsulate(self) -> Any:
def default_key(self) -> str | tuple[str, ...]:
msg = f"{self.__class__.__name__}.default_key() is not implemented"
raise NotImplementedError(msg)
+
+ @classmethod
+ def default(cls, *args: Any, **kwargs: Any) -> Callable[[CapsuleParams], Self]:
+ def callback(params: CapsuleParams) -> Self: # type: ignore[type-var,misc] # noqa: ARG001
+ return cls(*args, **kwargs)
+
+ return callback
diff --git a/capsula/_context/_command.py b/capsula/_context/_command.py
index 52e7c004..59c1367b 100644
--- a/capsula/_context/_command.py
+++ b/capsula/_context/_command.py
@@ -13,13 +13,21 @@
class CommandContext(ContextBase):
- def __init__(self, command: str, cwd: Path | None = None) -> None:
+ def __init__(self, command: str, *, cwd: Path | None = None, check: bool = False) -> None:
self.command = command
self.cwd = cwd
+ self.check = check
def encapsulate(self) -> dict:
logger.debug(f"Running command: {self.command}")
- output = subprocess.run(self.command, shell=True, text=True, capture_output=True, cwd=self.cwd, check=False) # noqa: S602
+ output = subprocess.run(
+ self.command,
+ shell=True, # noqa: S602
+ text=True,
+ capture_output=True,
+ cwd=self.cwd,
+ check=self.check,
+ )
logger.debug(f"Ran command: {self.command}. Result: {output}")
return {
"command": self.command,
diff --git a/capsula/_context/_file.py b/capsula/_context/_file.py
index daebad51..d4083878 100644
--- a/capsula/_context/_file.py
+++ b/capsula/_context/_file.py
@@ -1,14 +1,18 @@
from __future__ import annotations
import logging
+import warnings
from pathlib import Path
from shutil import copyfile, move
-from typing import Iterable
+from typing import TYPE_CHECKING, Callable, Iterable
from capsula._backport import file_digest
from ._base import ContextBase
+if TYPE_CHECKING:
+ from capsula._decorator import CapsuleParams
+
logger = logging.getLogger(__name__)
@@ -20,7 +24,7 @@ def __init__(
path: Path | str,
*,
compute_hash: bool = True,
- hash_algorithm: str | None,
+ hash_algorithm: str | None = None,
copy_to: Iterable[Path | str] | Path | str | None = None,
move_to: Path | str | None = None,
) -> None:
@@ -67,3 +71,29 @@ def encapsulate(self) -> dict:
def default_key(self) -> tuple[str, str]:
return ("file", str(self.path))
+
+ @classmethod
+ def default(
+ cls,
+ path: Path | str,
+ *,
+ compute_hash: bool = True,
+ hash_algorithm: str | None = None,
+ copy: bool = False,
+ move: bool = False,
+ ) -> Callable[[CapsuleParams], FileContext]:
+ if copy and move:
+ warnings.warn("Both copy and move are True. Only move will be performed.", UserWarning, stacklevel=2)
+ move = True
+ copy = False
+
+ def callback(params: CapsuleParams) -> FileContext:
+ return cls(
+ path=path,
+ compute_hash=compute_hash,
+ hash_algorithm=hash_algorithm,
+ copy_to=params.run_dir if copy else None,
+ move_to=params.run_dir if move else None,
+ )
+
+ return callback
diff --git a/capsula/_context/_git.py b/capsula/_context/_git.py
index 0c53216d..94ff405e 100644
--- a/capsula/_context/_git.py
+++ b/capsula/_context/_git.py
@@ -64,17 +64,22 @@ def default_key(self) -> tuple[str, str]:
return ("git", self.name)
@classmethod
- def default(cls) -> Callable[[CapsuleParams], GitRepositoryContext]:
+ def default(
+ cls,
+ name: str | None = None,
+ *,
+ allow_dirty: bool | None = None,
+ ) -> Callable[[CapsuleParams], GitRepositoryContext]:
def callback(params: CapsuleParams) -> GitRepositoryContext:
func_file_path = Path(inspect.getfile(params.func))
repo = Repo(func_file_path.parent, search_parent_directories=True)
repo_name = Path(repo.working_dir).name
return cls(
- name=Path(repo.working_dir).name,
+ name=Path(repo.working_dir).name if name is None else name,
path=Path(repo.working_dir),
diff_file=params.run_dir / f"{repo_name}.diff",
search_parent_directories=False,
- allow_dirty=True,
+ allow_dirty=True if allow_dirty is None else allow_dirty,
)
return callback
diff --git a/capsula/_decorator.py b/capsula/_decorator.py
index 0583e5c3..6d782542 100644
--- a/capsula/_decorator.py
+++ b/capsula/_decorator.py
@@ -1,13 +1,18 @@
from __future__ import annotations
+import inspect
+from datetime import datetime, timezone
+from pathlib import Path
+from random import choices
+from string import ascii_letters, digits
from typing import TYPE_CHECKING, Callable, Literal, TypeVar
+from capsula.utils import search_for_project_root
+
from ._backport import Concatenate, ParamSpec
from ._run import CapsuleParams, FuncInfo, Run
if TYPE_CHECKING:
- from pathlib import Path
-
from ._capsule import Capsule
from ._context import ContextBase
from ._reporter import ReporterBase
@@ -53,8 +58,17 @@ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
def run(
- run_dir: Path | Callable[[FuncInfo], Path],
+ run_dir: Path | Callable[[FuncInfo], Path] | None = None,
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
+ def _default_run_dir_generator(func_info: FuncInfo) -> Path:
+ project_root = search_for_project_root(Path(inspect.getfile(func_info.func)))
+ random_suffix = "".join(choices(ascii_letters + digits, k=4)) # noqa: S311
+ datetime_str = datetime.now(timezone.utc).astimezone().strftime(r"%Y%m%d_%H%M%S")
+ dir_name = f"{func_info.func.__name__}_{datetime_str}_{random_suffix}"
+ return project_root / "vault" / dir_name
+
+ run_dir = _default_run_dir_generator if run_dir is None else run_dir
+
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.set_run_dir(run_dir)
diff --git a/capsula/_reporter/_json.py b/capsula/_reporter/_json.py
index 0661e17c..84131047 100644
--- a/capsula/_reporter/_json.py
+++ b/capsula/_reporter/_json.py
@@ -12,6 +12,7 @@
from capsula.utils import to_nested_dict
if TYPE_CHECKING:
+ from capsula._decorator import CapsuleParams
from capsula.encapsulator import Capsule
from ._base import ReporterBase
@@ -47,7 +48,7 @@ def __init__(
self.path.parent.mkdir(parents=True, exist_ok=True)
if default is None:
- self.default = default_preset
+ self.default_for_encoder = default_preset
else:
def _default(obj: Any) -> Any:
@@ -56,7 +57,7 @@ def _default(obj: Any) -> Any:
except TypeError:
return default(obj)
- self.default = _default
+ self.default_for_encoder = _default
self.option = option
@@ -72,5 +73,19 @@ def _str_to_tuple(s: str | tuple[str, ...]) -> tuple[str, ...]:
if capsule.fails:
nested_data["__fails"] = to_nested_dict({_str_to_tuple(k): v for k, v in capsule.fails.items()})
- json_bytes = orjson.dumps(nested_data, default=self.default, option=self.option)
+ json_bytes = orjson.dumps(nested_data, default=self.default_for_encoder, option=self.option)
self.path.write_bytes(json_bytes)
+
+ @classmethod
+ def default(
+ cls,
+ *,
+ option: Optional[int] = None,
+ ) -> Callable[[CapsuleParams], JsonDumpReporter]:
+ def callback(params: CapsuleParams) -> JsonDumpReporter:
+ return cls(
+ params.run_dir / f"{params.phase}-run-report.json",
+ option=orjson.OPT_INDENT_2 if option is None else option,
+ )
+
+ return callback
diff --git a/capsula/_root.py b/capsula/_root.py
index bec1a441..a59017e5 100644
--- a/capsula/_root.py
+++ b/capsula/_root.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
from typing import Any
+from ._run import Run
from .encapsulator import Encapsulator, _CapsuleItemKey
@@ -9,3 +12,14 @@ def record(key: _CapsuleItemKey, value: Any) -> None:
msg = "No active encapsulator found."
raise RuntimeError(msg)
enc.record(key, value)
+
+
+def current_run_name() -> str:
+ run: Run | None = Run.get_current()
+ if run is None:
+ msg = "No active run found."
+ raise RuntimeError(msg)
+ if run.run_dir is None:
+ msg = "No active run directory found."
+ raise RuntimeError(msg)
+ return run.run_dir.name
diff --git a/capsula/_run.py b/capsula/_run.py
index 0942c349..1ba404b7 100644
--- a/capsula/_run.py
+++ b/capsula/_run.py
@@ -71,6 +71,7 @@ def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None:
self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func
self.run_dir_generator: Callable[[FuncInfo], Path] | None = None
+ self.run_dir: Path | None = None
def add_context(
self,
@@ -156,13 +157,13 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if self.run_dir_generator is None:
msg = "run_dir_generator must be set before calling the function."
raise ValueError(msg)
- run_dir = self.run_dir_generator(func_info)
- run_dir.mkdir(parents=True, exist_ok=True)
+ self.run_dir = self.run_dir_generator(func_info)
+ self.run_dir.mkdir(parents=True, exist_ok=True)
params = CapsuleParams(
func=func_info.func,
args=func_info.args,
kwargs=func_info.kwargs,
- run_dir=run_dir,
+ run_dir=self.run_dir,
phase="pre",
)
diff --git a/capsula/utils.py b/capsula/utils.py
index 75257752..2ceec77e 100644
--- a/capsula/utils.py
+++ b/capsula/utils.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Hashable
if TYPE_CHECKING:
+ from pathlib import Path
from types import TracebackType
@@ -71,3 +72,21 @@ def to_nested_dict(flat_dict: Mapping[Sequence[Hashable], Any]) -> dict[Hashable
else:
nested_dict[key[0]] = to_nested_dict({key[1:]: value})
return nested_dict
+
+
+def search_for_project_root(start: Path) -> Path:
+ """Search for the project root directory by looking for pyproject.toml.
+
+ Args:
+ start: The start directory to search.
+
+ Returns:
+ The project root directory.
+
+ """
+ if (start / "pyproject.toml").exists():
+ return start
+ if start == start.parent:
+ msg = "Project root not found."
+ raise FileNotFoundError(msg)
+ return search_for_project_root(start.parent)
diff --git a/coverage/badge.svg b/coverage/badge.svg
index 806459b4..7ea26315 100644
--- a/coverage/badge.svg
+++ b/coverage/badge.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/examples/decorator.py b/examples/decorator.py
index 992f2300..6341ee43 100644
--- a/examples/decorator.py
+++ b/examples/decorator.py
@@ -1,43 +1,27 @@
import logging
import random
-from datetime import UTC, datetime
from pathlib import Path
-import orjson
-
import capsula
logger = logging.getLogger(__name__)
-@capsula.run(
- run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"),
-)
-@capsula.context(
- lambda params: capsula.FileContext(
- Path(__file__).parents[1] / "pyproject.toml",
- hash_algorithm="sha256",
- copy_to=params.run_dir,
- ),
- mode="pre",
-)
-@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
-@capsula.reporter(
- lambda params: capsula.JsonDumpReporter(
- params.run_dir / f"{params.phase}-run-report.json",
- option=orjson.OPT_INDENT_2,
- ),
- mode="all",
-)
+@capsula.run()
+@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all")
+@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post")
+@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception"))
@capsula.watcher(capsula.TimeWatcher("calculation_time"))
-@capsula.context(
- lambda params: capsula.FileContext(
- Path(__file__).parent / "pi.txt",
- hash_algorithm="sha256",
- move_to=params.run_dir,
- ),
- mode="post",
-)
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre")
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre")
+@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre")
+@capsula.context(capsula.GitRepositoryContext.default("capsula"), mode="pre")
+@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre")
+@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre")
+@capsula.context(capsula.EnvVarContext("HOME"), mode="pre")
+@capsula.context(capsula.EnvVarContext("PATH"), mode="pre")
+@capsula.context(capsula.CwdContext(), mode="pre")
+@capsula.context(capsula.CpuContext(), mode="pre")
@capsula.pass_pre_run_capsule
def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Calculating pi with {n_samples} samples.")
@@ -54,6 +38,7 @@ def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, se
capsula.record("pi_estimate", pi_estimate)
# raise CapsulaError("This is a test error.")
logger.info(pre_run_capsule.data)
+ logger.info(capsula.current_run_name())
with (Path(__file__).parent / "pi.txt").open("w") as output_file:
output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")