Skip to content

Commit

Permalink
Merge pull request #164 from shunichironomura/make-default
Browse files Browse the repository at this point in the history
Make default methods for capsule items
  • Loading branch information
shunichironomura authored Feb 3, 2024
2 parents b78efff + 0f1e4eb commit 553fab7
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 83 deletions.
52 changes: 19 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,30 @@ See the following Python script:
```python
import logging
import random
from datetime import UTC, datetime
from pathlib import Path

import orjson

import capsula

logger = logging.getLogger(__name__)


@capsula.run(
run_dir=lambda _: Path(__file__).parents[1] / "vault" / datetime.now(UTC).astimezone().strftime(r"%Y%m%d_%H%M%S"),
)
@capsula.context(
lambda params: capsula.FileContext(
Path(__file__).parents[1] / "pyproject.toml",
hash_algorithm="sha256",
copy_to=params.run_dir,
),
mode="pre",
)
@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
@capsula.reporter(
lambda params: capsula.JsonDumpReporter(
params.run_dir / f"{params.phase}-run-report.json",
option=orjson.OPT_INDENT_2,
),
mode="all",
)
@capsula.run()
@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all")
@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post")
@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception"))
@capsula.watcher(capsula.TimeWatcher("calculation_time"))
@capsula.context(
lambda params: capsula.FileContext(
Path(__file__).parent / "pi.txt",
hash_algorithm="sha256",
move_to=params.run_dir,
),
mode="post",
)
def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre")
@capsula.context(capsula.GitRepositoryContext.default(), mode="pre")
@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre")
@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre")
@capsula.context(capsula.EnvVarContext("HOME"), mode="pre")
@capsula.context(capsula.EnvVarContext("PATH"), mode="pre")
@capsula.context(capsula.CwdContext(), mode="pre")
@capsula.context(capsula.CpuContext(), mode="pre")
@capsula.pass_pre_run_capsule
def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Calculating pi with {n_samples} samples.")
logger.debug(f"Seed: {seed}")
random.seed(seed)
Expand All @@ -76,13 +61,14 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
pi_estimate = (4.0 * inside) / n_samples
logger.info(f"Pi estimate: {pi_estimate}")
capsula.record("pi_estimate", pi_estimate)
logger.info(pre_run_capsule.data)
logger.info(capsula.current_run_name())

with (Path(__file__).parent / "pi.txt").open("w") as output_file:
output_file.write(str(pi_estimate))
output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
calculate_pi(n_samples=1_000)
```

Expand Down
3 changes: 2 additions & 1 deletion capsula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"WatcherBase",
"__version__",
"context",
"current_run_name",
"get_capsule_dir",
"get_capsule_name",
"monitor",
Expand All @@ -43,7 +44,7 @@
)
from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher
from ._reporter import JsonDumpReporter, ReporterBase
from ._root import record
from ._root import current_run_name, record
from ._run import Run
from ._version import __version__
from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase
Expand Down
12 changes: 10 additions & 2 deletions capsula/_capsule.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Tuple, Union

if TYPE_CHECKING:
from collections.abc import Mapping

from capsula._decorator import CapsuleParams
from capsula.utils import ExceptionInfo

from ._backport import TypeAlias
from ._backport import Self, TypeAlias


_ContextKey: TypeAlias = Union[str, Tuple[str, ...]]
Expand All @@ -32,3 +33,10 @@ def encapsulate(self) -> Any:
def default_key(self) -> str | tuple[str, ...]:
msg = f"{self.__class__.__name__}.default_key() is not implemented"
raise NotImplementedError(msg)

@classmethod
def default(cls, *args: Any, **kwargs: Any) -> Callable[[CapsuleParams], Self]:
    """Return a factory that builds an instance of this class on demand.

    The positional and keyword arguments are captured here and forwarded
    unchanged to the class constructor each time the returned factory is
    called. The ``params`` value handed to the factory is accepted but not
    used by this generic implementation.
    """

    def build_item(params: CapsuleParams) -> Self:  # type: ignore[type-var,misc] # noqa: ARG001
        instance = cls(*args, **kwargs)
        return instance

    return build_item
12 changes: 10 additions & 2 deletions capsula/_context/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@


class CommandContext(ContextBase):
def __init__(self, command: str, cwd: Path | None = None) -> None:
def __init__(self, command: str, *, cwd: Path | None = None, check: bool = False) -> None:
self.command = command
self.cwd = cwd
self.check = check

def encapsulate(self) -> dict:
logger.debug(f"Running command: {self.command}")
output = subprocess.run(self.command, shell=True, text=True, capture_output=True, cwd=self.cwd, check=False) # noqa: S602
output = subprocess.run(
self.command,
shell=True, # noqa: S602
text=True,
capture_output=True,
cwd=self.cwd,
check=self.check,
)
logger.debug(f"Ran command: {self.command}. Result: {output}")
return {
"command": self.command,
Expand Down
34 changes: 32 additions & 2 deletions capsula/_context/_file.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import logging
import warnings
from pathlib import Path
from shutil import copyfile, move
from typing import Iterable
from typing import TYPE_CHECKING, Callable, Iterable

from capsula._backport import file_digest

from ._base import ContextBase

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams

logger = logging.getLogger(__name__)


Expand All @@ -20,7 +24,7 @@ def __init__(
path: Path | str,
*,
compute_hash: bool = True,
hash_algorithm: str | None,
hash_algorithm: str | None = None,
copy_to: Iterable[Path | str] | Path | str | None = None,
move_to: Path | str | None = None,
) -> None:
Expand Down Expand Up @@ -67,3 +71,29 @@ def encapsulate(self) -> dict:

def default_key(self) -> tuple[str, str]:
return ("file", str(self.path))

@classmethod
def default(
    cls,
    path: Path | str,
    *,
    compute_hash: bool = True,
    hash_algorithm: str | None = None,
    copy: bool = False,
    move: bool = False,
) -> Callable[[CapsuleParams], FileContext]:
    """Return a factory that builds the context with the run directory as its target.

    Args:
        path: Passed through to the constructor as ``path``.
        compute_hash: Passed through to the constructor.
        hash_algorithm: Passed through to the constructor.
        copy: When true (and ``move`` is false), ``copy_to`` is set to ``params.run_dir``.
        move: When true, ``move_to`` is set to ``params.run_dir``.

    A ``UserWarning`` is emitted when both ``copy`` and ``move`` are true;
    in that case only the move is requested.
    """
    if copy and move:
        warnings.warn("Both copy and move are True. Only move will be performed.", UserWarning, stacklevel=2)

    # Move takes precedence over copy when both flags are given.
    wants_move = move
    wants_copy = copy and not move

    def build_context(params: CapsuleParams) -> FileContext:
        copy_target = params.run_dir if wants_copy else None
        move_target = params.run_dir if wants_move else None
        return cls(
            path=path,
            compute_hash=compute_hash,
            hash_algorithm=hash_algorithm,
            copy_to=copy_target,
            move_to=move_target,
        )

    return build_context
11 changes: 8 additions & 3 deletions capsula/_context/_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,22 @@ def default_key(self) -> tuple[str, str]:
return ("git", self.name)

@classmethod
def default(
    cls,
    name: str | None = None,
    *,
    allow_dirty: bool | None = None,
) -> Callable[[CapsuleParams], GitRepositoryContext]:
    """Return a factory that builds the context for the repository containing ``params.func``.

    Args:
        name: Name to give the context. When ``None``, the name of the
            repository's working directory is used.
        allow_dirty: Forwarded to the constructor. When ``None``, ``True``
            is used.

    The returned factory opens the repository starting from the directory
    of the source file of ``params.func`` (searching parent directories),
    and points ``diff_file`` at ``<working dir name>.diff`` inside
    ``params.run_dir``. The diff file name always uses the working
    directory name, even when ``name`` is overridden.
    """

    def callback(params: CapsuleParams) -> GitRepositoryContext:
        func_file_path = Path(inspect.getfile(params.func))
        repo = Repo(func_file_path.parent, search_parent_directories=True)
        working_dir = Path(repo.working_dir)
        repo_name = working_dir.name
        return cls(
            name=repo_name if name is None else name,
            path=working_dir,
            diff_file=params.run_dir / f"{repo_name}.diff",
            # The repository root is already resolved above, so no further search is needed.
            search_parent_directories=False,
            allow_dirty=True if allow_dirty is None else allow_dirty,
        )

    return callback
20 changes: 17 additions & 3 deletions capsula/_decorator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import inspect
from datetime import datetime, timezone
from pathlib import Path
from random import choices
from string import ascii_letters, digits
from typing import TYPE_CHECKING, Callable, Literal, TypeVar

from capsula.utils import search_for_project_root

from ._backport import Concatenate, ParamSpec
from ._run import CapsuleParams, FuncInfo, Run

if TYPE_CHECKING:
from pathlib import Path

from ._capsule import Capsule
from ._context import ContextBase
from ._reporter import ReporterBase
Expand Down Expand Up @@ -53,8 +58,17 @@ def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:


def run(
run_dir: Path | Callable[[FuncInfo], Path],
run_dir: Path | Callable[[FuncInfo], Path] | None = None,
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def _default_run_dir_generator(func_info: FuncInfo) -> Path:
project_root = search_for_project_root(Path(inspect.getfile(func_info.func)))
random_suffix = "".join(choices(ascii_letters + digits, k=4)) # noqa: S311
datetime_str = datetime.now(timezone.utc).astimezone().strftime(r"%Y%m%d_%H%M%S")
dir_name = f"{func_info.func.__name__}_{datetime_str}_{random_suffix}"
return project_root / "vault" / dir_name

run_dir = _default_run_dir_generator if run_dir is None else run_dir

def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.set_run_dir(run_dir)
Expand Down
21 changes: 18 additions & 3 deletions capsula/_reporter/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from capsula.utils import to_nested_dict

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams
from capsula.encapsulator import Capsule

from ._base import ReporterBase
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
self.path.parent.mkdir(parents=True, exist_ok=True)

if default is None:
self.default = default_preset
self.default_for_encoder = default_preset
else:

def _default(obj: Any) -> Any:
Expand All @@ -56,7 +57,7 @@ def _default(obj: Any) -> Any:
except TypeError:
return default(obj)

self.default = _default
self.default_for_encoder = _default

self.option = option

Expand All @@ -72,5 +73,19 @@ def _str_to_tuple(s: str | tuple[str, ...]) -> tuple[str, ...]:
if capsule.fails:
nested_data["__fails"] = to_nested_dict({_str_to_tuple(k): v for k, v in capsule.fails.items()})

json_bytes = orjson.dumps(nested_data, default=self.default, option=self.option)
json_bytes = orjson.dumps(nested_data, default=self.default_for_encoder, option=self.option)
self.path.write_bytes(json_bytes)

@classmethod
def default(
    cls,
    *,
    option: Optional[int] = None,
) -> Callable[[CapsuleParams], JsonDumpReporter]:
    """Return a factory that builds a reporter writing into the run directory.

    The report file is ``<phase>-run-report.json`` under ``params.run_dir``.

    Args:
        option: Value forwarded to the constructor as ``option``. When
            ``None``, ``orjson.OPT_INDENT_2`` is used.
    """

    def build_reporter(params: CapsuleParams) -> JsonDumpReporter:
        report_path = params.run_dir / f"{params.phase}-run-report.json"
        effective_option = option if option is not None else orjson.OPT_INDENT_2
        return cls(report_path, option=effective_option)

    return build_reporter
14 changes: 14 additions & 0 deletions capsula/_root.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

from typing import Any

from ._run import Run
from .encapsulator import Encapsulator, _CapsuleItemKey


Expand All @@ -9,3 +12,14 @@ def record(key: _CapsuleItemKey, value: Any) -> None:
msg = "No active encapsulator found."
raise RuntimeError(msg)
enc.record(key, value)


def current_run_name() -> str:
    """Return the name of the directory used by the currently active run.

    Raises:
        RuntimeError: If ``Run.get_current()`` returns ``None``, or if the
            current run has no ``run_dir`` set.
    """
    active_run: Run | None = Run.get_current()
    run_dir = None if active_run is None else active_run.run_dir
    if run_dir is None:
        # Distinguish "no run at all" from "run exists but has no directory yet".
        msg = "No active run found." if active_run is None else "No active run directory found."
        raise RuntimeError(msg)
    return run_dir.name
7 changes: 4 additions & 3 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None:
self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func

self.run_dir_generator: Callable[[FuncInfo], Path] | None = None
self.run_dir: Path | None = None

def add_context(
self,
Expand Down Expand Up @@ -156,13 +157,13 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
if self.run_dir_generator is None:
msg = "run_dir_generator must be set before calling the function."
raise ValueError(msg)
run_dir = self.run_dir_generator(func_info)
run_dir.mkdir(parents=True, exist_ok=True)
self.run_dir = self.run_dir_generator(func_info)
self.run_dir.mkdir(parents=True, exist_ok=True)
params = CapsuleParams(
func=func_info.func,
args=func_info.args,
kwargs=func_info.kwargs,
run_dir=run_dir,
run_dir=self.run_dir,
phase="pre",
)

Expand Down
19 changes: 19 additions & 0 deletions capsula/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Hashable

if TYPE_CHECKING:
from pathlib import Path
from types import TracebackType


Expand Down Expand Up @@ -71,3 +72,21 @@ def to_nested_dict(flat_dict: Mapping[Sequence[Hashable], Any]) -> dict[Hashable
else:
nested_dict[key[0]] = to_nested_dict({key[1:]: value})
return nested_dict


def search_for_project_root(start: Path) -> Path:
    """Search for the project root directory by looking for pyproject.toml.

    The search checks ``start`` itself and then each of its ancestors, from
    nearest to farthest, and stops at the first one that contains a
    ``pyproject.toml``.

    Args:
        start: The start directory to search. A relative path is interpreted
            against the current working directory.

    Returns:
        The project root directory.

    Raises:
        FileNotFoundError: If neither ``start`` nor any of its ancestors
            contains ``pyproject.toml``.
    """
    # A relative path such as Path(".") is its own parent, so walking up from
    # it would stop at the working directory and never reach real ancestors.
    start = start.absolute()
    for candidate in (start, *start.parents):
        if (candidate / "pyproject.toml").exists():
            return candidate
    msg = "Project root not found."
    raise FileNotFoundError(msg)
Loading

0 comments on commit 553fab7

Please sign in to comment.