Skip to content

Commit

Permalink
Merge pull request #259 from shunichironomura/path-relative-to-root
Browse files Browse the repository at this point in the history
Add relative_to_project_root configs
  • Loading branch information
shunichironomura authored Jul 6, 2024
2 parents a5340d4 + e569597 commit 99f92e4
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 31 deletions.
12 changes: 6 additions & 6 deletions capsula.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
contexts = [
{ type = "CwdContext" },
{ type = "CpuContext" },
{ type = "GitRepositoryContext", name = "capsula", path = "." },
{ type = "CommandContext", command = "poetry check --lock" },
{ type = "FileContext", path = "pyproject.toml", copy = true },
{ type = "FileContext", path = "poetry.lock", copy = true },
{ type = "CommandContext", command = "pip freeze --exclude-editable > requirements.txt" },
{ type = "FileContext", path = "requirements.txt", move = true },
{ type = "GitRepositoryContext", name = "capsula", path = ".", path_relative_to_project_root = true },
{ type = "CommandContext", command = "poetry check --lock", cwd = ".", cwd_relative_to_project_root = true },
{ type = "FileContext", path = "pyproject.toml", copy = true, path_relative_to_project_root = true },
{ type = "FileContext", path = "poetry.lock", copy = true, path_relative_to_project_root = true },
{ type = "CommandContext", command = "pip freeze --exclude-editable > requirements.txt", cwd = ".", cwd_relative_to_project_root = true },
{ type = "FileContext", path = "requirements.txt", move = true, path_relative_to_project_root = true },
]
reporters = [{ type = "JsonDumpReporter" }]

Expand Down
9 changes: 7 additions & 2 deletions capsula/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._backport import Annotated
from ._config import load_config
from ._context import ContextBase
from ._run import CapsuleParams, generate_default_run_dir
from ._run import CapsuleParams, generate_default_run_dir, get_project_root
from .utils import get_default_config_path

app = typer.Typer()
Expand Down Expand Up @@ -43,7 +43,12 @@ def enc(
exec_info = None
run_dir = generate_default_run_dir(exec_info=exec_info)
run_dir.mkdir(exist_ok=True, parents=True)
params = CapsuleParams(exec_info=exec_info, run_dir=run_dir, phase=phase.value)
params = CapsuleParams(
exec_info=exec_info,
run_dir=run_dir,
phase=phase.value,
project_root=get_project_root(exec_info),
)

for context in contexts:
enc.add_context(context if isinstance(context, ContextBase) else context(params))
Expand Down
35 changes: 32 additions & 3 deletions capsula/_context/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import logging
import subprocess
from typing import TYPE_CHECKING, TypedDict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, TypedDict

from ._base import ContextBase

if TYPE_CHECKING:
from pathlib import Path
from capsula._run import CapsuleParams

from ._base import ContextBase

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,3 +57,30 @@ def encapsulate(self) -> _CommandContextData:

def default_key(self) -> tuple[str, str]:
return ("command", self.command)

@classmethod
def default(
cls,
command: str,
*,
cwd: Path | str | None = None,
check: bool = True,
abort_on_error: bool = True,
cwd_relative_to_project_root: bool = False,
) -> Callable[[CapsuleParams], CommandContext]:
def callback(params: CapsuleParams) -> CommandContext:
if cwd_relative_to_project_root and cwd is not None and not Path(cwd).is_absolute():
cwd_path: Path | None = params.project_root / cwd
elif cwd_relative_to_project_root and cwd is None:
cwd_path = params.project_root
else:
cwd_path = Path(cwd) if cwd is not None else None

return cls(
command,
cwd=cwd_path,
check=check,
abort_on_error=abort_on_error,
)

return callback
8 changes: 7 additions & 1 deletion capsula/_context/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,21 @@ def default(
copy: bool = False,
move: bool = False,
ignore_missing: bool = False,
path_relative_to_project_root: bool = False,
) -> Callable[[CapsuleParams], FileContext]:
if copy and move:
warnings.warn("Both copy and move are True. Only move will be performed.", UserWarning, stacklevel=2)
move = True
copy = False

def callback(params: CapsuleParams) -> FileContext:
if path_relative_to_project_root and path is not None and not Path(path).is_absolute():
file_path = params.project_root / path
else:
file_path = Path(path)

return cls(
path=path,
path=file_path,
compute_hash=compute_hash,
hash_algorithm=hash_algorithm,
copy_to=params.run_dir if copy else None,
Expand Down
13 changes: 11 additions & 2 deletions capsula/_context/_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class _GitRepositoryContextData(TypedDict):
remotes: dict[str, str]
branch: str | None
is_dirty: bool
diff_file: PathLike[str] | str | None


class GitRepositoryContext(ContextBase):
Expand Down Expand Up @@ -67,6 +68,7 @@ def get_optional_branch_name(repo: Repo) -> str | None:
"remotes": {remote.name: remote.url for remote in repo.remotes},
"branch": get_optional_branch_name(repo),
"is_dirty": repo.is_dirty(),
"diff_file": None,
}

diff_txt = repo.git.diff()
Expand All @@ -75,6 +77,7 @@ def get_optional_branch_name(repo: Repo) -> str | None:
with self.diff_file.open("w") as f:
f.write(diff_txt)
logger.debug(f"Wrote diff to {self.diff_file}")
info["diff_file"] = self.diff_file
return info

def default_key(self) -> tuple[str, str]:
Expand All @@ -86,11 +89,17 @@ def default(
name: str | None = None,
*,
path: Path | str | None = None,
path_relative_to_project_root: bool = False,
allow_dirty: bool | None = None,
) -> Callable[[CapsuleParams], GitRepositoryContext]:
def callback(params: CapsuleParams) -> GitRepositoryContext:
if path is not None:
repo = Repo(path, search_parent_directories=False)
if path_relative_to_project_root and path is not None and not Path(path).is_absolute():
repository_path: Path | None = params.project_root / path
else:
repository_path = Path(path) if path is not None else None

if repository_path is not None:
repo = Repo(repository_path, search_parent_directories=False)
else:
if isinstance(params.exec_info, FuncInfo):
repo_search_start_path = Path(inspect.getfile(params.exec_info.func)).parent
Expand Down
16 changes: 13 additions & 3 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,20 @@ class CapsuleParams:
exec_info: FuncInfo | CommandInfo | None
run_dir: Path
phase: Literal["pre", "in", "post"]
project_root: Path


ExecInfo: TypeAlias = Union[FuncInfo, CommandInfo]


def generate_default_run_dir(exec_info: ExecInfo | None = None) -> Path:
exec_name: str | None
project_root = get_project_root(exec_info)
if exec_info is None:
project_root = search_for_project_root(Path.cwd())
exec_name = None
elif isinstance(exec_info, CommandInfo):
project_root = search_for_project_root(Path.cwd())
exec_name = exec_info.command.split()[0] # TODO: handle more complex commands
elif isinstance(exec_info, FuncInfo):
project_root = search_for_project_root(Path(inspect.getfile(exec_info.func)))
exec_name = exec_info.func.__name__
else:
msg = f"exec_info must be an instance of FuncInfo or CommandInfo, not {type(exec_info)}."
Expand All @@ -74,6 +73,16 @@ def generate_default_run_dir(exec_info: ExecInfo | None = None) -> Path:
return project_root / "vault" / dir_name


def get_project_root(exec_info: ExecInfo | None = None) -> Path:
if exec_info is None or isinstance(exec_info, CommandInfo):
return search_for_project_root(Path.cwd())
elif isinstance(exec_info, FuncInfo):
return search_for_project_root(Path(inspect.getfile(exec_info.func)))
else:
msg = f"exec_info must be an instance of FuncInfo or CommandInfo, not {type(exec_info)}."
raise TypeError(msg)


class Run(Generic[_P, _T]):
_thread_local = threading.local()

Expand Down Expand Up @@ -252,6 +261,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: # noqa: C901
exec_info=func_info,
run_dir=self._run_dir,
phase="pre",
project_root=get_project_root(func_info),
)

pre_run_enc = Encapsulator()
Expand Down
5 changes: 3 additions & 2 deletions capsula/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def to_nested_dict(flat_dict: Mapping[Sequence[Hashable], Any]) -> dict[Hashable
return nested_dict


def search_for_project_root(start: Path) -> Path:
def search_for_project_root(start: Path | str) -> Path:
"""Search for the project root directory by looking for pyproject.toml.
Args:
Expand All @@ -84,12 +84,13 @@ def search_for_project_root(start: Path) -> Path:
The project root directory.
"""
start = Path(start)
if (start / "pyproject.toml").exists():
return start
if start == start.parent:
msg = "Project root not found."
raise FileNotFoundError(msg)
return search_for_project_root(start.parent)
return search_for_project_root(start.resolve().parent)


def get_default_config_path() -> Path:
Expand Down
20 changes: 13 additions & 7 deletions examples/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,30 @@
from rich.logging import RichHandler

import capsula
import capsula.utils

logger = logging.getLogger(__name__)

PROJECT_ROOT = capsula.utils.search_for_project_root(__file__)


@capsula.run(ignore_config=True)
@capsula.context(capsula.EnvVarContext("HOME"), mode="pre")
@capsula.context(capsula.EnvVarContext("PATH"), mode="pre")
@capsula.context(capsula.CwdContext(), mode="pre")
@capsula.context(capsula.CpuContext(), mode="pre")
@capsula.context(capsula.GitRepositoryContext.default("capsula"), mode="pre")
@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "pyproject.toml", copy=True), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre")
@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre")
@capsula.context(capsula.CommandContext("poetry check --lock", cwd=PROJECT_ROOT), mode="pre")
@capsula.context(capsula.FileContext.default(PROJECT_ROOT / "pyproject.toml", copy=True), mode="pre")
@capsula.context(capsula.FileContext.default(PROJECT_ROOT / "poetry.lock", copy=True), mode="pre")
@capsula.context(
capsula.CommandContext("pip freeze --exclude-editable > requirements.txt", cwd=PROJECT_ROOT),
mode="pre",
)
@capsula.context(capsula.FileContext.default(PROJECT_ROOT / "requirements.txt", move=True), mode="pre")
@capsula.watcher(capsula.UncaughtExceptionWatcher("Exception"))
@capsula.watcher(capsula.TimeWatcher("calculation_time"))
@capsula.context(capsula.FileContext.default(Path(__file__).parent / "pi.txt", move=True), mode="post")
@capsula.context(capsula.FileContext.default("pi.txt", move=True), mode="post")
@capsula.reporter(capsula.JsonDumpReporter.default(), mode="all")
@capsula.pass_pre_run_capsule
def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, seed: int = 42) -> None:
Expand All @@ -41,7 +47,7 @@ def calculate_pi(pre_run_capsule: capsula.Capsule, *, n_samples: int = 1_000, se
# raise CapsulaError("This is a test error.")
logger.info(f"Run name: {capsula.current_run_name()}")

with (Path(__file__).parent / "pi.txt").open("w") as output_file:
with (Path("pi.txt")).open("w") as output_file:
output_file.write(f"Pi estimate: {pi_estimate}. Git SHA: {pre_run_capsule.data[('git', 'capsula')]['sha']}")


Expand Down
4 changes: 2 additions & 2 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
@capsula.context(capsula.CwdContext(), mode="pre")
@capsula.context(capsula.EnvVarContext("PATH"), mode="pre")
@capsula.context(capsula.EnvVarContext("HOME"), mode="pre")
@capsula.context(capsula.CommandContext("pip freeze --exclude-editable > requirements.txt"), mode="pre")
@capsula.context(capsula.CommandContext("poetry check --lock"), mode="pre")
@capsula.context(capsula.CommandContext.default("pip freeze --exclude-editable > requirements.txt"), mode="pre")
@capsula.context(capsula.CommandContext.default("poetry check --lock"), mode="pre")
@capsula.context(capsula.GitRepositoryContext.default("capsula"), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "requirements.txt", move=True), mode="pre")
@capsula.context(capsula.FileContext.default(Path(__file__).parents[1] / "poetry.lock", copy=True), mode="pre")
Expand Down
4 changes: 1 addition & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def test_search_for_project_root_in_parent() -> None:
child_dir.mkdir()

# Test that the search correctly finds the project root in the parent directory
assert (
search_for_project_root(child_dir) == project_root
), "Failed to find the project root in the parent directory"
assert search_for_project_root(child_dir).resolve() == project_root.resolve()


def test_search_for_project_root_not_found() -> None:
Expand Down

0 comments on commit 99f92e4

Please sign in to comment.