Skip to content

Commit

Permalink
Merge pull request #163 from shunichironomura/pass-pre-run-capsule
Browse files Browse the repository at this point in the history
Add `pass_pre_run_capsule` decorator
  • Loading branch information
shunichironomura authored Feb 3, 2024
2 parents eeacb3f + 9659e05 commit b78efff
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 23 deletions.
5 changes: 4 additions & 1 deletion capsula/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"CapsulaConfigurationError",
"CapsulaError",
"Capsule",
"CommandContext",
"ContextBase",
"CpuContext",
Expand All @@ -21,13 +22,15 @@
"get_capsule_dir",
"get_capsule_name",
"monitor",
"pass_pre_run_capsule",
"record",
"reporter",
"run",
"set_capsule_dir",
"set_capsule_name",
"watcher",
]
from ._capsule import Capsule
from ._context import (
CommandContext,
ContextBase,
Expand All @@ -38,7 +41,7 @@
GitRepositoryContext,
PlatformContext,
)
from ._decorator import context, reporter, run, watcher
from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher
from ._reporter import JsonDumpReporter, ReporterBase
from ._root import record
from ._run import Run
Expand Down
6 changes: 3 additions & 3 deletions capsula/_backport.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__all__ = ["ParamSpec", "Self", "TypeAlias", "file_digest"]
__all__ = ["Concatenate", "ParamSpec", "Self", "TypeAlias", "file_digest"]

import hashlib
import sys
Expand All @@ -12,9 +12,9 @@
from typing_extensions import Self

if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeAlias
from typing import Concatenate, ParamSpec, TypeAlias
else:
from typing_extensions import ParamSpec, TypeAlias
from typing_extensions import Concatenate, ParamSpec, TypeAlias


if sys.version_info >= (3, 11):
Expand Down
22 changes: 11 additions & 11 deletions capsula/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from typing import TYPE_CHECKING, Callable, Literal, TypeVar

from ._backport import ParamSpec
from ._backport import Concatenate, ParamSpec
from ._run import CapsuleParams, FuncInfo, Run

if TYPE_CHECKING:
from pathlib import Path

from capsula._reporter import ReporterBase

from ._capsule import Capsule
from ._context import ContextBase
from ._reporter import ReporterBase
from ._watcher import WatcherBase

_P = ParamSpec("_P")
Expand All @@ -21,8 +21,7 @@ def watcher(
watcher: WatcherBase | Callable[[CapsuleParams], WatcherBase],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_watcher(watcher)
return run

Expand All @@ -34,8 +33,7 @@ def reporter(
mode: Literal["pre", "in", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_reporter(reporter, mode=mode)
return run

Expand All @@ -47,8 +45,7 @@ def context(
mode: Literal["pre", "post", "all"],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.add_context(context, mode=mode)
return run

Expand All @@ -59,10 +56,13 @@ def run(
run_dir: Path | Callable[[FuncInfo], Path],
) -> Callable[[Callable[_P, _T] | Run[_P, _T]], Run[_P, _T]]:
def decorator(func_or_run: Callable[_P, _T] | Run[_P, _T]) -> Run[_P, _T]:
func = func_or_run.func if isinstance(func_or_run, Run) else func_or_run
run = func_or_run if isinstance(func_or_run, Run) else Run(func)
run = func_or_run if isinstance(func_or_run, Run) else Run(func_or_run)
run.set_run_dir(run_dir)

return run

return decorator


def pass_pre_run_capsule(func: Callable[Concatenate[Capsule, _P], _T]) -> Run[_P, _T]:
return Run(func, pass_pre_run_capsule=True)
25 changes: 20 additions & 5 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
import queue
import threading
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar
from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar, overload

from pydantic import BaseModel

from capsula._reporter import ReporterBase
from capsula.encapsulator import Encapsulator

from ._backport import ParamSpec, Self
from ._backport import Concatenate, ParamSpec, Self
from ._context import ContextBase
from ._watcher import WatcherBase

if TYPE_CHECKING:
from types import TracebackType

from ._capsule import Capsule

_P = ParamSpec("_P")
_T = TypeVar("_T")

Expand Down Expand Up @@ -48,7 +50,15 @@ def get_current(cls) -> Self | None:
except IndexError:
return None

def __init__(self, func: Callable[_P, _T]) -> None:
@overload
def __init__(self, func: Callable[_P, _T], *, pass_pre_run_capsule: Literal[False] = False) -> None:
...

@overload
def __init__(self, func: Callable[Concatenate[Capsule, _P], _T], *, pass_pre_run_capsule: Literal[True]) -> None:
...

def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None:
self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = []
self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
Expand All @@ -57,7 +67,9 @@ def __init__(self, func: Callable[_P, _T]) -> None:
self.in_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []
self.post_run_reporter_generators: list[Callable[[CapsuleParams], ReporterBase]] = []

self.func = func
self.pass_pre_run_capsule: bool = pass_pre_run_capsule
self.func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T] = func

self.run_dir_generator: Callable[[FuncInfo], Path] | None = None

def add_context(
Expand Down Expand Up @@ -170,7 +182,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
in_run_enc.add_watcher(watcher)

with self, in_run_enc, in_run_enc.watch():
result = self.func(*args, **kwargs)
if self.pass_pre_run_capsule:
result = self.func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self.func(*args, **kwargs)

in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self.in_run_reporter_generators:
Expand Down
2 changes: 1 addition & 1 deletion coverage/badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 4 additions & 2 deletions examples/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
),
mode="post",
)
def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
@capsula.pass_pre_run_capsule
def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Calculating pi with {n_samples} samples.")
logger.debug(f"Seed: {seed}")
random.seed(seed)
Expand All @@ -52,9 +53,10 @@ def calculate_pi(*, n_samples: int = 1_000, seed: int = 42) -> None:
logger.info(f"Pi estimate: {pi_estimate}")
capsula.record("pi_estimate", pi_estimate)
# raise CapsulaError("This is a test error.")
logger.info(pre_run_capsule.data)

with (Path(__file__).parent / "pi.txt").open("w") as output_file:
output_file.write(str(pi_estimate))
output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")


if __name__ == "__main__":
Expand Down

0 comments on commit b78efff

Please sign in to comment.