Skip to content

Commit

Permalink
Merge pull request #148 from shunichironomura/decorator-v2
Browse files Browse the repository at this point in the history
Implement a new capsula.run decorator
  • Loading branch information
shunichironomura authored Jan 28, 2024
2 parents a08fec2 + 983beb3 commit c039775
Show file tree
Hide file tree
Showing 32 changed files with 628 additions and 100 deletions.
35 changes: 35 additions & 0 deletions capsula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,44 @@
"set_capsule_dir",
"set_capsule_name",
"Encapsulator",
"capsule",
"record",
"Run",
"ContextBase",
"CwdContext",
"EnvVarContext",
"GitRepositoryContext",
"FileContext",
"PlatformContext",
"CpuContext",
"CommandContext",
"JsonDumpReporter",
"ReporterBase",
"WatcherBase",
"TimeWatcher",
"watcher",
"reporter",
"context",
"UncaughtExceptionWatcher",
"run",
]
from ._context import (
CommandContext,
ContextBase,
CpuContext,
CwdContext,
EnvVarContext,
FileContext,
GitRepositoryContext,
PlatformContext,
)
from ._decorator import capsule, context, reporter, run, watcher
from ._monitor import monitor
from ._reporter import JsonDumpReporter, ReporterBase
from ._root import record
from ._run import Run
from ._version import __version__
from ._watcher import TimeWatcher, UncaughtExceptionWatcher, WatcherBase
from .encapsulator import Encapsulator
from .exceptions import CapsulaConfigurationError, CapsulaError
from .globalvars import get_capsule_dir, get_capsule_name, set_capsule_dir, set_capsule_name
10 changes: 3 additions & 7 deletions capsula/_backport.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from __future__ import annotations

__all__ = [
"TypeAlias",
"file_digest",
"Self",
]
__all__ = ["TypeAlias", "file_digest", "Self", "ParamSpec"]

import hashlib
import sys
Expand All @@ -16,9 +12,9 @@
from typing_extensions import Self

if sys.version_info >= (3, 10):
from typing import TypeAlias
from typing import ParamSpec, TypeAlias
else:
from typing_extensions import TypeAlias
from typing_extensions import ParamSpec, TypeAlias


if sys.version_info >= (3, 11):
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions capsula/context/__init__.py → capsula/_context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = [
"Context",
"ContextBase",
"CwdContext",
"EnvVarContext",
"GitRepositoryContext",
Expand All @@ -8,7 +8,7 @@
"CpuContext",
"CommandContext",
]
from ._base import Context
from ._base import ContextBase
from ._command import CommandContext
from ._cpu import CpuContext
from ._cwd import CwdContext
Expand Down
5 changes: 5 additions & 0 deletions capsula/_context/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from capsula._capsule import CapsuleItem


class ContextBase(CapsuleItem):
pass
4 changes: 2 additions & 2 deletions capsula/context/_command.py → capsula/_context/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
if TYPE_CHECKING:
from pathlib import Path

from ._base import Context
from ._base import ContextBase

logger = logging.getLogger(__name__)


class CommandContext(Context):
class CommandContext(ContextBase):
def __init__(self, command: str, cwd: Path | None = None) -> None:
self.command = command
self.cwd = cwd
Expand Down
4 changes: 2 additions & 2 deletions capsula/context/_cpu.py → capsula/_context/_cpu.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from cpuinfo import get_cpu_info

from ._base import Context
from ._base import ContextBase


class CpuContext(Context):
class CpuContext(ContextBase):
def encapsulate(self) -> dict:
return get_cpu_info()

Expand Down
4 changes: 2 additions & 2 deletions capsula/context/_cwd.py → capsula/_context/_cwd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path

from ._base import Context
from ._base import ContextBase


class CwdContext(Context):
class CwdContext(ContextBase):
def encapsulate(self) -> Path:
return Path.cwd()

Expand Down
4 changes: 2 additions & 2 deletions capsula/context/_envvar.py → capsula/_context/_envvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import os

from ._base import Context
from ._base import ContextBase


class EnvVarContext(Context):
class EnvVarContext(ContextBase):
def __init__(self, name: str) -> None:
self.name = name

Expand Down
18 changes: 9 additions & 9 deletions capsula/context/_file.py → capsula/_context/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from capsula._backport import file_digest

from ._base import Context
from ._base import ContextBase

logger = logging.getLogger(__name__)


class FileContext(Context):
class FileContext(ContextBase):
def __init__(
self,
path: Path | str,
Expand All @@ -32,13 +32,11 @@ def __init__(
else:
self.copy_to = tuple(Path(p) for p in copy_to)

def normalize_copy_dst_path(p: Path) -> Path:
if p.is_dir():
return p / self.path.name
else:
return p

self.copy_to = tuple(normalize_copy_dst_path(p) for p in self.copy_to)
def _normalize_copy_dst_path(self, p: Path) -> Path:
if p.is_dir():
return p / self.path.name
else:
return p

def encapsulate(self) -> dict:
if self.hash_algorithm is None:
Expand All @@ -47,6 +45,8 @@ def encapsulate(self) -> dict:
with self.path.open("rb") as f:
digest = file_digest(f, self.hash_algorithm).hexdigest()

self.copy_to = tuple(self._normalize_copy_dst_path(p) for p in self.copy_to)

info: dict = {
"hash": {
"algorithm": self.hash_algorithm,
Expand Down
25 changes: 23 additions & 2 deletions capsula/context/_git.py → capsula/_context/_git.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

import inspect
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Callable

from git.repo import Repo

from capsula.exceptions import CapsulaError

from ._base import Context
from ._base import ContextBase

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams

logger = logging.getLogger(__name__)

Expand All @@ -18,7 +23,7 @@ def __init__(self, repo: Repo) -> None:
super().__init__(f"Repository {repo.working_dir} is dirty")


class GitRepositoryContext(Context):
class GitRepositoryContext(ContextBase):
def __init__(
self,
name: str,
Expand Down Expand Up @@ -57,3 +62,19 @@ def encapsulate(self) -> dict:

def default_key(self) -> tuple[str, str]:
return ("git", self.name)

@classmethod
def default(cls) -> Callable[[CapsuleParams], GitRepositoryContext]:
def callback(params: CapsuleParams) -> GitRepositoryContext:
func_file_path = Path(inspect.getfile(params.func))
repo = Repo(func_file_path.parent, search_parent_directories=True)
repo_name = Path(repo.working_dir).name
return cls(
name=Path(repo.working_dir).name,
path=Path(repo.working_dir),
diff_file=params.run_dir / f"{repo_name}.diff",
search_parent_directories=False,
allow_dirty=True,
)

return callback
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import platform as pf

from ._base import Context
from ._base import ContextBase


class PlatformContext(Context):
class PlatformContext(ContextBase):
def encapsulate(self) -> dict:
return {
"machine": pf.machine(),
Expand Down
132 changes: 132 additions & 0 deletions capsula/_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal, Tuple, TypeVar, Union

from capsula._reporter import ReporterBase
from capsula.encapsulator import Encapsulator

from ._backport import ParamSpec
from ._context import ContextBase
from ._run import CapsuleParams, FuncInfo, Run
from ._watcher import WatcherBase

if TYPE_CHECKING:
from collections.abc import Sequence

from ._backport import TypeAlias

_P = ParamSpec("_P")
_T = TypeVar("_T")


_ContextInput: TypeAlias = Union[
ContextBase,
Tuple[ContextBase, Tuple[str, ...]],
Callable[[Path, Callable], Union[ContextBase, Tuple[ContextBase, Tuple[str, ...]]]],
]
_WatcherInput: TypeAlias = Union[
WatcherBase,
Tuple[WatcherBase, Tuple[str, ...]],
Callable[[Path, Callable], Union[WatcherBase, Tuple[WatcherBase, Tuple[str, ...]]]],
]
_ReporterInput: TypeAlias = Union[ReporterBase, Callable[[Path, Callable], ReporterBase]]


def capsule( # noqa: C901
capsule_directory: Path | str | None = None,
pre_run_contexts: Sequence[_ContextInput] | None = None,
pre_run_reporters: Sequence[_ReporterInput] | None = None,
in_run_watchers: Sequence[_WatcherInput] | None = None,
post_run_contexts: Sequence[_ContextInput] | None = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
if capsule_directory is None:
raise NotImplementedError
capsule_directory = Path(capsule_directory)

assert pre_run_contexts is not None
assert pre_run_reporters is not None
assert in_run_watchers is not None
assert post_run_contexts is not None

def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
pre_run_enc = Encapsulator()
for cxt in pre_run_contexts:
if isinstance(cxt, ContextBase):
pre_run_enc.add_context(cxt)
elif isinstance(cxt, tuple):
pre_run_enc.add_context(cxt[0], key=cxt[1])
else:
cxt_hydrated = cxt(capsule_directory, func)
if isinstance(cxt_hydrated, ContextBase):
pre_run_enc.add_context(cxt_hydrated)
elif isinstance(cxt_hydrated, tuple):
pre_run_enc.add_context(cxt_hydrated[0], key=cxt_hydrated[1])

@wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
capsule_directory.mkdir(parents=True, exist_ok=True)
pre_run_capsule = pre_run_enc.encapsulate()
for reporter in pre_run_reporters:
if isinstance(reporter, ReporterBase):
reporter.report(pre_run_capsule)
else:
reporter(capsule_directory, func).report(pre_run_capsule)

return func(*args, **kwargs)

return wrapper

return decorator


def watcher(
watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run.add_watcher(watcher)
return run

return decorator


def reporter(
reporter: ReporterBase | Callable[[CapsuleParams], ReporterBase],
mode: Literal["pre", "in", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run.add_reporter(reporter, mode=mode)
return run

return decorator


def context(
context: ContextBase | Callable[[CapsuleParams], ContextBase],
mode: Literal["pre", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run.add_context(context, mode=mode)
return run

return decorator


def run(
run_dir: Path | Callable[[FuncInfo], Path],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run.set_run_dir(run_dir)

return run

return decorator
3 changes: 3 additions & 0 deletions capsula/_reporter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = ["JsonDumpReporter", "ReporterBase"]
from ._base import ReporterBase
from ._json import JsonDumpReporter
2 changes: 1 addition & 1 deletion capsula/reporter/_base.py → capsula/_reporter/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from capsula.encapsulator import Capsule


class Reporter(ABC):
class ReporterBase(ABC):
@abstractmethod
def report(self, capsule: Capsule) -> None:
raise NotImplementedError
Loading

0 comments on commit c039775

Please sign in to comment.