diff --git a/capsula/_cli.py b/capsula/_cli.py index 278dc318..1e63a477 100644 --- a/capsula/_cli.py +++ b/capsula/_cli.py @@ -1,10 +1,16 @@ +from __future__ import annotations + +import logging +import shlex from datetime import datetime, timezone from enum import Enum +from pathlib import Path from random import choices from string import ascii_letters, digits -from typing import Literal, NoReturn +from typing import Any, List, Literal, NoReturn, Optional import typer +from rich.console import Console import capsula @@ -13,20 +19,92 @@ from ._context import ContextBase from ._run import ( CapsuleParams, + Run, + RunDtoCommand, default_run_name_factory, get_default_vault_dir, get_project_root, ) -from ._utils import get_default_config_path +from ._utils import get_default_config_path, search_for_project_root + +logger = logging.getLogger(__name__) app = typer.Typer() +console = Console() +err_console = Console(stderr=True) @app.command() -def run() -> NoReturn: - typer.echo("Running...") +def run( + command: Annotated[List[str], typer.Argument(help="Command to run", show_default=False)], + run_name: Annotated[ + Optional[str], + typer.Option( + ..., + help="Run name. Make sure it is unique. If not provided, it will be generated randomly.", + ), + ] = None, + vault_dir: Annotated[ + Optional[Path], + typer.Option( + ..., + help="Vault directory. If not provided, it will be set to the default value.", + ), + ] = None, + ignore_config: Annotated[ + bool, + typer.Option( + ..., + help="Ignore the configuration file and run the command directly.", + ), + ] = False, + config_path: Annotated[ + Optional[Path], + typer.Option( + ..., + help="Path to the Capsula configuration file.", + ), + ] = None, +) -> NoReturn: + err_console.print(f"Running command '{shlex.join(command)}'...") + run_dto = RunDtoCommand( + run_name_factory=default_run_name_factory if run_name is None else lambda _x, _y, _z: run_name, + vault_dir=vault_dir, + command=tuple(command), + ) - raise typer.Exit + if not ignore_config: + config = load_config(get_default_config_path() if config_path is None else config_path) + for phase in ("pre", "in", "post"): + phase_key = f"{phase}-run" + if phase_key not in config: + continue + for context in reversed(config[phase_key].get("contexts", [])): # type: ignore[literal-required] + assert phase in {"pre", "post"}, f"Invalid phase for context: {phase}" + run_dto.add_context(context, mode=phase, append_left=True) # type: ignore[arg-type] + for watcher in reversed(config[phase_key].get("watchers", [])): # type: ignore[literal-required] + assert phase == "in", "Watcher can only be added to the in-run phase." + # No need to set append_left=True here, as watchers are added as the outermost context manager + run_dto.add_watcher(watcher, append_left=False) + for reporter in reversed(config[phase_key].get("reporters", [])): # type: ignore[literal-required] + assert phase in {"pre", "in", "post"}, f"Invalid phase for reporter: {phase}" + run_dto.add_reporter(reporter, mode=phase, append_left=True) # type: ignore[arg-type] + + run_dto.vault_dir = config["vault-dir"] if run_dto.vault_dir is None else run_dto.vault_dir + + # Set the vault directory if it is not set by the config file + if run_dto.vault_dir is None: + project_root = search_for_project_root(Path.cwd()) + run_dto.vault_dir = project_root / "vault" + + run: Run[Any, Any] = Run(run_dto) + result, params = run.exec_command() + console.print(result.stdout, end="") + err_console.print(result.stderr, end="") + err_console.print(f"Run directory: {params.run_dir}") + err_console.print(f"Command exited with code {result.returncode}") + + raise typer.Exit(result.returncode) class _PhaseForEncapsulate(str, Enum): diff --git a/capsula/_run.py b/capsula/_run.py index 191586d2..44e84741 100644 --- a/capsula/_run.py +++ b/capsula/_run.py @@ -3,6 +3,7 @@ import inspect import logging import queue +import subprocess import threading from collections import deque from dataclasses import dataclass, field @@ -29,6 +30,8 @@ P = ParamSpec("P") T = TypeVar("T") +_P = ParamSpec("_P") +_T = TypeVar("_T") logger = logging.getLogger(__name__) @@ -50,7 +53,7 @@ def bound_args(self) -> OrderedDict[str, Any]: @dataclass class CommandInfo: - command: str + command: tuple[str, ...] @dataclass @@ -75,7 +78,7 @@ def default_run_name_factory(exec_info: ExecInfo | None, random_str: str, timest if exec_info is None: exec_name = None elif isinstance(exec_info, CommandInfo): - exec_name = exec_info.command.split()[0] # TODO: handle more complex commands + exec_name = exec_info.command[0] elif isinstance(exec_info, FuncInfo): exec_name = exec_info.func.__name__ else: @@ -210,6 +213,11 @@ class RunDtoNoPassPreRunCapsule(_RunDtoBase, Generic[P, T]): func: Callable[P, T] | None = None +@dataclass +class RunDtoCommand(_RunDtoBase): + command: tuple[str, ...] | None = None + + class Run(Generic[P, T]): _thread_local = threading.local() @@ -228,7 +236,7 @@ def get_current(cls) -> Self: def __init__( self, - run_dto: RunDtoPassPreRunCapsule[P, T] | RunDtoNoPassPreRunCapsule[P, T], + run_dto: RunDtoPassPreRunCapsule[P, T] | RunDtoNoPassPreRunCapsule[P, T] | RunDtoCommand, /, ) -> None: self._pre_run_context_generators = run_dto.pre_run_context_generators @@ -241,10 +249,6 @@ def __init__( self._pass_pre_run_capsule: bool = isinstance(run_dto, RunDtoPassPreRunCapsule) - if run_dto.func is None: - raise CapsulaUninitializedError("func") - self._func: Callable[P, T] | Callable[Concatenate[Capsule, P], T] = run_dto.func - if run_dto.run_name_factory is None: raise CapsulaUninitializedError("run_name_factory") self._run_name_factory: Callable[[ExecInfo | None, str, datetime], str] = run_dto.run_name_factory @@ -255,6 +259,21 @@ def __init__( self._run_dir: Path | None = None + if isinstance(run_dto, RunDtoCommand): + if run_dto.command is None: + raise CapsulaUninitializedError("command") + self._func: Callable[P, T] | Callable[Concatenate[Capsule, P], T] | None = None + self._command: tuple[str, ...] | None = run_dto.command + elif isinstance(run_dto, (RunDtoPassPreRunCapsule, RunDtoNoPassPreRunCapsule)): + if run_dto.func is None: + raise CapsulaUninitializedError("func") + self._func = run_dto.func + self._command = None + else: + msg = "run_dto must be an instance of RunDtoCommand, RunDtoPassPreRunCapsule, or RunDtoNoPassPreRunCapsule," + " not {type(run_dto)}." + raise TypeError(msg) + @property def run_dir(self) -> Path: if self._run_dir is None: @@ -274,7 +293,7 @@ def __exit__( ) -> None: self._get_run_stack().get(block=False) - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR0912, PLR0915 + def pre_run(self, exec_info: ExecInfo) -> tuple[CapsuleParams, Capsule]: if self._vault_dir.exists(): if not self._vault_dir.is_dir(): msg = f"Vault directory {self._vault_dir} exists but is not a directory." @@ -288,11 +307,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09 gitignore_file.write("*\n") logger.info(f"Vault directory: {self._vault_dir}") - func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs, pass_pre_run_capsule=self._pass_pre_run_capsule) - # Generate the run name run_name = self._run_name_factory( - func_info, + exec_info, "".join(choices(ascii_letters + digits, k=4)), datetime.now(timezone.utc), ) @@ -310,11 +327,11 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09 logger.info(f"Run directory: {self._run_dir}") params = CapsuleParams( - exec_info=func_info, + exec_info=exec_info, run_name=run_name, run_dir=self._run_dir, phase="pre", - project_root=get_project_root(func_info), + project_root=get_project_root(exec_info), ) pre_run_enc = Encapsulator() @@ -326,38 +343,83 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09 reporter = reporter_generator(params) reporter.report(pre_run_capsule) + return params, pre_run_capsule + + def post_run(self, params: CapsuleParams) -> Capsule: + params.phase = "post" + post_run_enc = Encapsulator() + for context_generator in self._post_run_context_generators: + context = context_generator(params) + post_run_enc.add_context(context) + post_run_capsule = post_run_enc.encapsulate() + for reporter_generator in self._post_run_reporter_generators: + reporter = reporter_generator(params) + try: + reporter.report(post_run_capsule) + except Exception: + logger.exception(f"Failed to report post-run capsule with reporter {reporter}.") + + return post_run_capsule + + def in_run(self, params: CapsuleParams, func: Callable[[], _T]) -> _T: params.phase = "in" in_run_enc = Encapsulator() for watcher_generator in self._in_run_watcher_generators: watcher = watcher_generator(params) in_run_enc.add_watcher(watcher) + with self, in_run_enc, in_run_enc.watch(): + result = func() + + in_run_capsule = in_run_enc.encapsulate() + for reporter_generator in self._in_run_reporter_generators: + reporter = reporter_generator(params) + try: + reporter.report(in_run_capsule) + except Exception: + logger.exception(f"Failed to report in-run capsule with reporter {reporter}.") + + return result + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: + assert self._func is not None + func_info = FuncInfo(func=self._func, args=args, kwargs=kwargs, pass_pre_run_capsule=self._pass_pre_run_capsule) + params, pre_run_capsule = self.pre_run(func_info) + + if self._pass_pre_run_capsule: + + def _func_1() -> T: + assert self._func is not None + return self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type] + + func = _func_1 + else: + + def _func_2() -> T: + assert self._func is not None + return self._func(*args, **kwargs) + + func = _func_2 + try: - with self, in_run_enc, in_run_enc.watch(): - if self._pass_pre_run_capsule: - result = self._func(pre_run_capsule, *args, **kwargs) # type: ignore[arg-type] - else: - result = self._func(*args, **kwargs) + result = self.in_run(params, func) finally: - in_run_capsule = in_run_enc.encapsulate() - for reporter_generator in self._in_run_reporter_generators: - reporter = reporter_generator(params) - try: - reporter.report(in_run_capsule) - except Exception: - logger.exception(f"Failed to report in-run capsule with reporter {reporter}.") - - params.phase = "post" - post_run_enc = Encapsulator() - for context_generator in self._post_run_context_generators: - context = context_generator(params) - post_run_enc.add_context(context) - post_run_capsule = post_run_enc.encapsulate() - for reporter_generator in self._post_run_reporter_generators: - reporter = reporter_generator(params) - try: - reporter.report(post_run_capsule) - except Exception: - logger.exception(f"Failed to report post-run capsule with reporter {reporter}.") + _post_run_capsule = self.post_run(params) return result + + def exec_command(self) -> tuple[subprocess.CompletedProcess[str], CapsuleParams]: + assert self._command is not None + command_info = CommandInfo(command=self._command) + params, _pre_run_capsule = self.pre_run(command_info) + + def func() -> subprocess.CompletedProcess[str]: + assert self._command is not None + return subprocess.run(self._command, check=False, capture_output=True, text=True) # noqa: S603 + + try: + result = self.in_run(params, func) + finally: + _post_run_capsule = self.post_run(params) + + return result, params diff --git a/pyproject.toml b/pyproject.toml index cdb16587..4de745da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,3 +148,4 @@ ignore = [ ] "scripts/*" = ["INP001", "EXE003", "T201"] "tests/**" = ["SLF001"] +"capsula/_cli.py" = ["UP007", "FBT002", "UP006", "TCH001"]