diff --git a/README.md b/README.md index 6d1360c1..55621e62 100644 --- a/README.md +++ b/README.md @@ -25,45 +25,30 @@ See the following Python script: ```python import logging import random -from datetime import UTC, datetime from pathlib import Path -import orjson - import capsula logger = logging.getLogger(__name__) -@capsula.run( - run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"), -) -@capsula.context( - lambda params: capsula.FileContext( - Path(__file__).parents[1] / "pyproject.toml", - hash_algorithm="sha256", - copy_to=params.run_dir, - ), - mode="pre", -) -@capsula.context(capsula.GitRepositoryContext.default(), mode="pre") -@capsula.reporter( - lambda params: capsula.JsonDumpReporter( - params.run_dir / f"{params.phase}-run-report.json", - option=orjson.OPT_INDENT_2, - ), - mode="all", -) +@capsula.run() +@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all") +@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post") +@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception")) @capsula.watcher(capsula.TimeWatcher("calculation_time")) -@capsula.context( - lambda params: capsula.FileContext( - Path(__file__).parent / "pi.txt", - hash_algorithm="sha256", - move_to=params.run_dir, - ), - mode="post", -) -def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None: +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre") +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre") +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre") +@capsula.context(capsula.GitRepositoryContext.default(), mode="pre") +@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre") +@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre") +@capsula.context(capsula.EnvVarContext("HOME"), mode="pre") +@capsula.context(capsula.EnvVarContext("PATH"), mode="pre") +@capsula.context(capsula.CwdContext(), mode="pre") +@capsula.context(capsula.CpuContext(), mode="pre") +@capsula.pass_pre_run_capsule +def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None: logger.info(f"Calculating pi with {n_samples} samples.") logger.debug(f"Seed: {seed}") random.seed(seed) @@ -76,13 +61,14 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None: pi_estimate = (4.0 * inside) / n_samples logger.info(f"Pi estimate: {pi_estimate}") capsula.record("pi_estimate", pi_estimate) + logger.info(pre_run_capsule.data) + logger.info(capsula.current_run_name()) with (Path(__file__).parent / "pi.txt").open("w") as output_file: - output_file.write(str(pi_estimate)) + output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}") if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) calculate_pi(n_samples=1_000) ``` diff --git a/capsula/__init__.py b/capsula/__init__.py index a04f2ce3..4f35fc90 100644 --- a/capsula/__init__.py +++ b/capsula/__init__.py @@ -19,6 +19,7 @@ "WatcherBase", "__version__", "context", + "current_run_name", "get_capsule_dir", "get_capsule_name", "monitor", @@ -43,7 +44,7 @@ ) from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher from ._reporter import JsonDumpReporter, ReporterBase -from ._root import record +from ._root import current_run_name, record from ._run import Run from ._version import __version__ from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase diff --git a/capsula/_capsule.py b/capsula/_capsule.py index 0a471290..092ef68a 100644 --- a/capsula/_capsule.py +++ b/capsula/_capsule.py @@ -1,14 +1,15 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Tuple, Union if TYPE_CHECKING: from collections.abc import Mapping + from capsula._decorator import CapsuleParams from capsula.utils import ExceptionInfo - from ._backport import TypeAlias + from ._backport import Self, TypeAlias _ContextKey: TypeAlias = Union[str, Tuple[str, ...]] @@ -32,3 +33,10 @@ def encapsulate(self) -> Any: def default_key(self) -> str | tuple[str, ...]: msg = f"{self.__class__.__name__}.default_key() is not implemented" raise NotImplementedError(msg) + + @classmethod + def default(cls, *args: Any, **kwargs: Any) -> Callable[[CapsuleParams], Self]: + def callback(params: CapsuleParams) -> Self: # type: ignore[type-var,misc] # noqa: ARG001 + return cls(*args, **kwargs) + + return callback diff --git a/capsula/_context/_command.py b/capsula/_context/_command.py index 52e7c004..59c1367b 100644 --- a/capsula/_context/_command.py +++ b/capsula/_context/_command.py @@ -13,13 +13,21 @@ class CommandContext(ContextBase): - def __init__(self, command: str, cwd: Path | None = None) -> None: + def __init__(self, command: str, *, cwd: Path | None = None, check: bool = False) -> None: self.command = command self.cwd = cwd + self.check = check def encapsulate(self) -> dict: logger.debug(f"Running command: {self.command}") - output = subprocess.run(self.command, shell=True, text=True, capture_output=True, cwd=self.cwd, check=False) # noqa: S602 + output = subprocess.run( + self.command, + shell=True, # noqa: S602 + text=True, + capture_output=True, + cwd=self.cwd, + check=self.check, + ) logger.debug(f"Ran command: {self.command}. Result: {output}") return { "command": self.command, diff --git a/capsula/_context/_file.py b/capsula/_context/_file.py index daebad51..d4083878 100644 --- a/capsula/_context/_file.py +++ b/capsula/_context/_file.py @@ -1,14 +1,18 @@ from __future__ import annotations import logging +import warnings from pathlib import Path from shutil import copyfile, move -from typing import Iterable +from typing import TYPE_CHECKING, Callable, Iterable from capsula._backport import file_digest from ._base import ContextBase +if TYPE_CHECKING: + from capsula._decorator import CapsuleParams + logger = logging.getLogger(__name__) @@ -20,7 +24,7 @@ def __init__( path: Path | str, *, compute_hash: bool = True, - hash_algorithm: str | None, + hash_algorithm: str | None = None, copy_to: Iterable[Path | str] | Path | str | None = None, move_to: Path | str | None = None, ) -> None: @@ -67,3 +71,29 @@ def encapsulate(self) -> dict: def default_key(self) -> tuple[str, str]: return ("file", str(self.path)) + + @classmethod + def default( + cls, + path: Path | str, + *, + compute_hash: bool = True, + hash_algorithm: str | None = None, + copy: bool = False, + move: bool = False, + ) -> Callable[[CapsuleParams], FileContext]: + if copy and move: + warnings.warn("Both copy and move are True. Only move will be performed.", UserWarning, stacklevel=2) + move = True + copy = False + + def callback(params: CapsuleParams) -> FileContext: + return cls( + path=path, + compute_hash=compute_hash, + hash_algorithm=hash_algorithm, + copy_to=params.run_dir if copy else None, + move_to=params.run_dir if move else None, + ) + + return callback diff --git a/capsula/_context/_git.py b/capsula/_context/_git.py index 0c53216d..94ff405e 100644 --- a/capsula/_context/_git.py +++ b/capsula/_context/_git.py @@ -64,17 +64,22 @@ def default_key(self) -> tuple[str, str]: return ("git", self.name) @classmethod - def default(cls) -> Callable[[CapsuleParams], GitRepositoryContext]: + def default( + cls, + name: str | None = None, + *, + allow_dirty: bool | None = None, + ) -> Callable[[CapsuleParams], GitRepositoryContext]: def callback(params: CapsuleParams) -> GitRepositoryContext: func_file_path = Path(inspect.getfile(params.func)) repo = Repo(func_file_path.parent, search_parent_directories=True) repo_name = Path(repo.working_dir).name return cls( - name=Path(repo.working_dir).name, + name=Path(repo.working_dir).name if name is None else name, path=Path(repo.working_dir), diff_file=params.run_dir / f"{repo_name}.diff", search_parent_directories=False, - allow_dirty=True, + allow_dirty=True if allow_dirty is None else allow_dirty, ) return callback diff --git a/capsula/_decorator.py b/capsula/_decorator.py index 0583e5c3..6d782542 100644 --- a/capsula/_decorator.py +++ b/capsula/_decorator.py @@ -1,13 +1,18 @@ from __future__ import annotations +import inspect +from datetime import datetime, timezone +from pathlib import Path +from random import choices +from string import ascii_letters, digits from typing import TYPE_CHECKING, Callable, Literal, TypeVar +from capsula.utils import search_for_project_root + from ._backport import Concatenate, ParamSpec from ._run import CapsuleParams, FuncInfo, Run if TYPE_CHECKING: - from pathlib import Path - from ._capsule import Capsule from ._context import ContextBase from ._reporter import ReporterBase @@ -53,8 +58,17 @@ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: def run( - run_dir: Path | Callable[[FuncInfo], Path], + run_dir: Path | Callable[[FuncInfo], Path] | None = None, ) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]: + def _default_run_dir_generator(func_info: FuncInfo) -> Path: + project_root = search_for_project_root(Path(inspect.getfile(func_info.func))) + random_suffix = "".join(choices(ascii_letters + digits, k=4)) # noqa: S311 + datetime_str = datetime.now(timezone.utc).astimezone().strftime(r"%Y%m%d_%H%M%S") + dir_name = f"{func_info.func.__name__}_{datetime_str}_{random_suffix}" + return project_root / "vault" / dir_name + + run_dir = _default_run_dir_generator if run_dir is None else run_dir + def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]: run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run) run.set_run_dir(run_dir) diff --git a/capsula/_reporter/_json.py b/capsula/_reporter/_json.py index 0661e17c..84131047 100644 --- a/capsula/_reporter/_json.py +++ b/capsula/_reporter/_json.py @@ -12,6 +12,7 @@ from capsula.utils import to_nested_dict if TYPE_CHECKING: + from capsula._decorator import CapsuleParams from capsula.encapsulator import Capsule from ._base import ReporterBase @@ -47,7 +48,7 @@ def __init__( self.path.parent.mkdir(parents=True, exist_ok=True) if default is None: - self.default = default_preset + self.default_for_encoder = default_preset else: def _default(obj: Any) -> Any: @@ -56,7 +57,7 @@ def _default(obj: Any) -> Any: except TypeError: return default(obj) - self.default = _default + self.default_for_encoder = _default self.option = option @@ -72,5 +73,19 @@ def _str_to_tuple(s: str | tuple[str, ...]) -> tuple[str, ...]: if capsule.fails: nested_data["__fails"] = to_nested_dict({_str_to_tuple(k): v for k, v in capsule.fails.items()}) - json_bytes = orjson.dumps(nested_data, default=self.default, option=self.option) + json_bytes = orjson.dumps(nested_data, default=self.default_for_encoder, option=self.option) self.path.write_bytes(json_bytes) + + @classmethod + def default( + cls, + *, + option: Optional[int] = None, + ) -> Callable[[CapsuleParams], JsonDumpReporter]: + def callback(params: CapsuleParams) -> JsonDumpReporter: + return cls( + params.run_dir / f"{params.phase}-run-report.json", + option=orjson.OPT_INDENT_2 if option is None else option, + ) + + return callback diff --git a/capsula/_root.py b/capsula/_root.py index bec1a441..a59017e5 100644 --- a/capsula/_root.py +++ b/capsula/_root.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from typing import Any +from ._run import Run from .encapsulator import Encapsulator, _CapsuleItemKey @@ -9,3 +12,14 @@ def record(key: _CapsuleItemKey, value: Any) -> None: msg = "No active encapsulator found." raise RuntimeError(msg) enc.record(key, value) + + +def current_run_name() -> str: + run: Run | None = Run.get_current() + if run is None: + msg = "No active run found." + raise RuntimeError(msg) + if run.run_dir is None: + msg = "No active run directory found." + raise RuntimeError(msg) + return run.run_dir.name diff --git a/capsula/_run.py b/capsula/_run.py index 0942c349..1ba404b7 100644 --- a/capsula/_run.py +++ b/capsula/_run.py @@ -71,6 +71,7 @@ def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None: self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func self.run_dir_generator: Callable[[FuncInfo], Path] | None = None + self.run_dir: Path | None = None def add_context( self, @@ -156,13 +157,13 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: if self.run_dir_generator is None: msg = "run_dir_generator must be set before calling the function." raise ValueError(msg) - run_dir = self.run_dir_generator(func_info) - run_dir.mkdir(parents=True, exist_ok=True) + self.run_dir = self.run_dir_generator(func_info) + self.run_dir.mkdir(parents=True, exist_ok=True) params = CapsuleParams( func=func_info.func, args=func_info.args, kwargs=func_info.kwargs, - run_dir=run_dir, + run_dir=self.run_dir, phase="pre", ) diff --git a/capsula/utils.py b/capsula/utils.py index 75257752..2ceec77e 100644 --- a/capsula/utils.py +++ b/capsula/utils.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Hashable if TYPE_CHECKING: + from pathlib import Path from types import TracebackType @@ -71,3 +72,21 @@ def to_nested_dict(flat_dict: Mapping[Sequence[Hashable], Any]) -> dict[Hashable else: nested_dict[key[0]] = to_nested_dict({key[1:]: value}) return nested_dict + + +def search_for_project_root(start: Path) -> Path: + """Search for the project root directory by looking for pyproject.toml. + + Args: + start: The start directory to search. + + Returns: + The project root directory. + + """ + if (start / "pyproject.toml").exists(): + return start + if start == start.parent: + msg = "Project root not found." + raise FileNotFoundError(msg) + return search_for_project_root(start.parent) diff --git a/coverage/badge.svg b/coverage/badge.svg index 806459b4..7ea26315 100644 --- a/coverage/badge.svg +++ b/coverage/badge.svg @@ -1 +1 @@ -coverage: 47.24%coverage47.24% \ No newline at end of file +coverage: 46.08%coverage46.08% \ No newline at end of file diff --git a/examples/decorator.py b/examples/decorator.py index 992f2300..6341ee43 100644 --- a/examples/decorator.py +++ b/examples/decorator.py @@ -1,43 +1,27 @@ import logging import random -from datetime import UTC, datetime from pathlib import Path -import orjson - import capsula logger = logging.getLogger(__name__) -@capsula.run( - run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"), -) -@capsula.context( - lambda params: capsula.FileContext( - Path(__file__).parents[1] / "pyproject.toml", - hash_algorithm="sha256", - copy_to=params.run_dir, - ), - mode="pre", -) -@capsula.context(capsula.GitRepositoryContext.default(), mode="pre") -@capsula.reporter( - lambda params: capsula.JsonDumpReporter( - params.run_dir / f"{params.phase}-run-report.json", - option=orjson.OPT_INDENT_2, - ), - mode="all", -) +@capsula.run() +@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all") +@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post") +@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception")) @capsula.watcher(capsula.TimeWatcher("calculation_time")) -@capsula.context( - lambda params: capsula.FileContext( - Path(__file__).parent / "pi.txt", - hash_algorithm="sha256", - move_to=params.run_dir, - ), - mode="post", -) +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre") +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre") +@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre") +@capsula.context(capsula.GitRepositoryContext.default("capsula"), mode="pre") +@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre") +@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre") +@capsula.context(capsula.EnvVarContext("HOME"), mode="pre") +@capsula.context(capsula.EnvVarContext("PATH"), mode="pre") +@capsula.context(capsula.CwdContext(), mode="pre") +@capsula.context(capsula.CpuContext(), mode="pre") @capsula.pass_pre_run_capsule def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None: logger.info(f"Calculating pi with {n_samples} samples.") @@ -54,6 +38,7 @@ def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, se capsula.record("pi_estimate", pi_estimate) # raise CapsulaError("This is a test error.") logger.info(pre_run_capsule.data) + logger.info(capsula.current_run_name()) with (Path(__file__).parent / "pi.txt").open("w") as output_file: output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")