Skip to content

Commit

Permalink
Merge pull request #347 from shunichironomura/run-command
Browse files Browse the repository at this point in the history
run-command
  • Loading branch information
shunichironomura authored Sep 23, 2024
2 parents 9ec252a + f367fb6 commit ec48ea8
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 43 deletions.
88 changes: 83 additions & 5 deletions capsula/_cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from __future__ import annotations

import logging
import shlex
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from random import choices
from string import ascii_letters, digits
from typing import Literal, NoReturn
from typing import Any, List, Literal, NoReturn, Optional

import typer
from rich.console import Console

import capsula

Expand All @@ -13,20 +19,92 @@
from ._context import ContextBase
from ._run import (
CapsuleParams,
Run,
RunDtoCommand,
default_run_name_factory,
get_default_vault_dir,
get_project_root,
)
from ._utils import get_default_config_path
from ._utils import get_default_config_path, search_for_project_root

logger = logging.getLogger(__name__)

app = typer.Typer()
console = Console()
err_console = Console(stderr=True)


@app.command()
def run() -> NoReturn:
typer.echo("Running...")
def run(
command: Annotated[List[str], typer.Argument(help="Command to run", show_default=False)],
run_name: Annotated[
Optional[str],
typer.Option(
...,
help="Run name. Make sure it is unique. If not provided, it will be generated randomly.",
),
] = None,
vault_dir: Annotated[
Optional[Path],
typer.Option(
...,
help="Vault directory. If not provided, it will be set to the default value.",
),
] = None,
ignore_config: Annotated[
bool,
typer.Option(
...,
help="Ignore the configuration file and run the command directly.",
),
] = False,
config_path: Annotated[
Optional[Path],
typer.Option(
...,
help="Path to the Capsula configuration file.",
),
] = None,
) -> NoReturn:
err_console.print(f"Running command '{shlex.join(command)}'...")
run_dto = RunDtoCommand(
run_name_factory=default_run_name_factory if run_name is None else lambda _x, _y, _z: run_name,
vault_dir=vault_dir,
command=tuple(command),
)

raise typer.Exit
if not ignore_config:
config = load_config(get_default_config_path() if config_path is None else config_path)
for phase in ("pre", "in", "post"):
phase_key = f"{phase}-run"
if phase_key not in config:
continue
for context in reversed(config[phase_key].get("contexts", [])): # type: ignore[literal-required]
assert phase in {"pre", "post"}, f"Invalid phase for context: {phase}"
run_dto.add_context(context, mode=phase, append_left=True) # type: ignore[arg-type]
for watcher in reversed(config[phase_key].get("watchers", [])): # type: ignore[literal-required]
assert phase == "in", "Watcher can only be added to the in-run phase."
# No need to set append_left=True here, as watchers are added as the outermost context manager
run_dto.add_watcher(watcher, append_left=False)
for reporter in reversed(config[phase_key].get("reporters", [])): # type: ignore[literal-required]
assert phase in {"pre", "in", "post"}, f"Invalid phase for reporter: {phase}"
run_dto.add_reporter(reporter, mode=phase, append_left=True) # type: ignore[arg-type]

run_dto.vault_dir = config["vault-dir"] if run_dto.vault_dir is None else run_dto.vault_dir

# Set the vault directory if it is not set by the config file
if run_dto.vault_dir is None:
project_root = search_for_project_root(Path.cwd())
run_dto.vault_dir = project_root / "vault"

run: Run[Any, Any] = Run(run_dto)
result, params = run.exec_command()
console.print(result.stdout, end="")
err_console.print(result.stderr, end="")
err_console.print(f"Run directory: {params.run_dir}")
err_console.print(f"Command exited with code {result.returncode}")

raise typer.Exit(result.returncode)


class _PhaseForEncapsulate(str, Enum):
Expand Down
138 changes: 100 additions & 38 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import logging
import queue
import subprocess
import threading
from collections import deque
from dataclasses import dataclass, field
Expand All @@ -29,6 +30,8 @@

P = ParamSpec("P")
T = TypeVar("T")
_P = ParamSpec("_P")
_T = TypeVar("_T")

logger = logging.getLogger(__name__)

Expand All @@ -50,7 +53,7 @@ def bound_args(self) -> OrderedDict[str, Any]:

@dataclass
class CommandInfo:
command: str
command: tuple[str, ...]


@dataclass
Expand All @@ -75,7 +78,7 @@ def default_run_name_factory(exec_info: ExecInfo | None, random_str: str, timest
if exec_info is None:
exec_name = None
elif isinstance(exec_info, CommandInfo):
exec_name = exec_info.command.split()[0] # TODO: handle more complex commands
exec_name = exec_info.command[0]
elif isinstance(exec_info, FuncInfo):
exec_name = exec_info.func.__name__
else:
Expand Down Expand Up @@ -210,6 +213,11 @@ class RunDtoNoPassPreRunCapsule(_RunDtoBase, Generic[P, T]):
func: Callable[P, T] | None = None


@dataclass
class RunDtoCommand(_RunDtoBase):
command: tuple[str, ...] | None = None


class Run(Generic[P, T]):
_thread_local = threading.local()

Expand All @@ -228,7 +236,7 @@ def get_current(cls) -> Self:

def __init__(
self,
run_dto: RunDtoPassPreRunCapsule[P, T] | RunDtoNoPassPreRunCapsule[P, T],
run_dto: RunDtoPassPreRunCapsule[P, T] | RunDtoNoPassPreRunCapsule[P, T] | RunDtoCommand,
/,
) -> None:
self._pre_run_context_generators = run_dto.pre_run_context_generators
Expand All @@ -241,10 +249,6 @@ def __init__(

self._pass_pre_run_capsule: bool = isinstance(run_dto, RunDtoPassPreRunCapsule)

if run_dto.func is None:
raise CapsulaUninitializedError("func")
self._func: Callable[P, T] | Callable[Concatenate[Capsule, P], T] = run_dto.func

if run_dto.run_name_factory is None:
raise CapsulaUninitializedError("run_name_factory")
self._run_name_factory: Callable[[ExecInfo | None, str, datetime], str] = run_dto.run_name_factory
Expand All @@ -255,6 +259,21 @@ def __init__(

self._run_dir: Path | None = None

if isinstance(run_dto, RunDtoCommand):
if run_dto.command is None:
raise CapsulaUninitializedError("command")
self._func: Callable[P, T] | Callable[Concatenate[Capsule, P], T] | None = None
self._command: tuple[str, ...] | None = run_dto.command
elif isinstance(run_dto, (RunDtoPassPreRunCapsule, RunDtoNoPassPreRunCapsule)):
if run_dto.func is None:
raise CapsulaUninitializedError("func")
self._func = run_dto.func
self._command = None
else:
msg = "run_dto must be an instance of RunDtoCommand, RunDtoPassPreRunCapsule, or RunDtoNoPassPreRunCapsule,"
" not {type(run_dto)}."
raise TypeError(msg)

@property
def run_dir(self) -> Path:
if self._run_dir is None:
Expand All @@ -274,7 +293,7 @@ def __exit__(
) -> None:
self._get_run_stack().get(block=False)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR0912, PLR0915
def pre_run(self, exec_info: ExecInfo) -> tuple[CapsuleParams, Capsule]:
if self._vault_dir.exists():
if not self._vault_dir.is_dir():
msg = f"Vault directory {self._vault_dir} exists but is not a directory."
Expand All @@ -288,11 +307,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09
gitignore_file.write("*\n")
logger.info(f"Vault directory: {self._vault_dir}")

func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs, pass_pre_run_capsule=self._pass_pre_run_capsule)

# Generate the run name
run_name = self._run_name_factory(
func_info,
exec_info,
"".join(choices(ascii_letters + digits, k=4)),
datetime.now(timezone.utc),
)
Expand All @@ -310,11 +327,11 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09
logger.info(f"Run directory: {self._run_dir}")

params = CapsuleParams(
exec_info=func_info,
exec_info=exec_info,
run_name=run_name,
run_dir=self._run_dir,
phase="pre",
project_root=get_project_root(func_info),
project_root=get_project_root(exec_info),
)

pre_run_enc = Encapsulator()
Expand All @@ -326,38 +343,83 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09
reporter = reporter_generator(params)
reporter.report(pre_run_capsule)

return params, pre_run_capsule

def post_run(self, params: CapsuleParams) -> Capsule:
params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(post_run_capsule)
except Exception:
logger.exception(f"Failed to report post-run capsule with reporter {reporter}.")

return post_run_capsule

def in_run(self, params: CapsuleParams, func: Callable[[], _T]) -> _T:
params.phase = "in"
in_run_enc = Encapsulator()
for watcher_generator in self._in_run_watcher_generators:
watcher = watcher_generator(params)
in_run_enc.add_watcher(watcher)

with self, in_run_enc, in_run_enc.watch():
result = func()

in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(in_run_capsule)
except Exception:
logger.exception(f"Failed to report in-run capsule with reporter {reporter}.")

return result

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
assert self._func is not None
func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs, pass_pre_run_capsule=self._pass_pre_run_capsule)
params, pre_run_capsule = self.pre_run(func_info)

if self._pass_pre_run_capsule:

def _func_1() -> T:
assert self._func is not None
return self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]

func = _func_1
else:

def _func_2() -> T:
assert self._func is not None
return self._func(*args, **kwargs)

func = _func_2

try:
with self, in_run_enc, in_run_enc.watch():
if self._pass_pre_run_capsule:
result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type]
else:
result = self._func(*args, **kwargs)
result = self.in_run(params, func)
finally:
in_run_capsule = in_run_enc.encapsulate()
for reporter_generator in self._in_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(in_run_capsule)
except Exception:
logger.exception(f"Failed to report in-run capsule with reporter {reporter}.")

params.phase = "post"
post_run_enc = Encapsulator()
for context_generator in self._post_run_context_generators:
context = context_generator(params)
post_run_enc.add_context(context)
post_run_capsule = post_run_enc.encapsulate()
for reporter_generator in self._post_run_reporter_generators:
reporter = reporter_generator(params)
try:
reporter.report(post_run_capsule)
except Exception:
logger.exception(f"Failed to report post-run capsule with reporter {reporter}.")
_post_run_capsule = self.post_run(params)

return result

def exec_command(self) -> tuple[subprocess.CompletedProcess[str], CapsuleParams]:
assert self._command is not None
command_info = CommandInfo(command=self._command)
params, _pre_run_capsule = self.pre_run(command_info)

def func() -> subprocess.CompletedProcess[str]:
assert self._command is not None
return subprocess.run(self._command, check=False, capture_output=True, text=True) # noqa: S603

try:
result = self.in_run(params, func)
finally:
_post_run_capsule = self.post_run(params)

return result, params
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ ignore = [
]
"scripts/*" = ["INP001", "EXE003", "T201"]
"tests/**" = ["SLF001"]
"capsula/_cli.py" = ["UP007", "FBT002", "UP006", "TCH001"]

0 comments on commit ec48ea8

Please sign in to comment.