Skip to content

Commit

Permalink
Merge pull request #174 from shunichironomura/fix-mypy-strict
Browse files Browse the repository at this point in the history
Update type hints for mypy to run in strict mode
  • Loading branch information
shunichironomura authored Feb 4, 2024
2 parents 8e66442 + a883a6d commit 1735303
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ jobs:
run: poetry install --no-interaction

- name: Run mypy
run: poetry run --no-interaction mypy .
run: poetry run --no-interaction mypy --strict .

tox:
if: github.event.pull_request.draft == false
Expand Down
13 changes: 9 additions & 4 deletions capsula/_backport.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__all__ = ["Concatenate", "ParamSpec", "Self", "TypeAlias", "file_digest"]
__all__ = ["AbstractContextManager", "Concatenate", "ParamSpec", "Self", "TypeAlias", "file_digest"]

import hashlib
import sys
Expand All @@ -16,16 +16,21 @@
else:
from typing_extensions import Concatenate, ParamSpec, TypeAlias

if sys.version_info >= (3, 9):
from contextlib import AbstractContextManager
else:
from typing import ContextManager as AbstractContextManager


if sys.version_info >= (3, 11):
file_digest = hashlib.file_digest
else:
if TYPE_CHECKING:
import io
from typing_extensions import Buffer
from typing import Protocol

class _BytesIOLike(Protocol):
def getbuffer(self) -> io.ReadableBuffer:
def getbuffer(self) -> Buffer:
...

class _FileDigestFileObj(Protocol):
Expand All @@ -41,7 +46,7 @@ def file_digest(
/,
*,
_bufsize: int = 2**18,
):
) -> hashlib._Hash:
"""Hash the contents of a file-like object. Returns a digest object.
*fileobj* must be a file-like object opened for reading in binary mode.
Expand Down
2 changes: 1 addition & 1 deletion capsula/_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from capsula._decorator import CapsuleParams
from capsula._run import CapsuleParams
from capsula.utils import ExceptionInfo

from ._backport import Self, TypeAlias
Expand Down
12 changes: 10 additions & 2 deletions capsula/_context/_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import subprocess
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
from pathlib import Path
Expand All @@ -12,13 +12,21 @@
logger = logging.getLogger(__name__)


class _CommandContextData(TypedDict):
command: str
cwd: Path | None
returncode: int
stdout: str
stderr: str


class CommandContext(ContextBase):
def __init__(self, command: str, *, cwd: Path | None = None, check: bool = False) -> None:
self.command = command
self.cwd = cwd
self.check = check

def encapsulate(self) -> dict:
def encapsulate(self) -> _CommandContextData:
logger.debug(f"Running command: {self.command}")
output = subprocess.run(
self.command,
Expand Down
8 changes: 6 additions & 2 deletions capsula/_context/_cpu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

from typing import Any

from cpuinfo import get_cpu_info

from ._base import ContextBase


class CpuContext(ContextBase):
def encapsulate(self) -> dict:
return get_cpu_info()
def encapsulate(self) -> dict[str, Any]:
return get_cpu_info() # type: ignore[no-any-return]

def default_key(self) -> str:
return "cpu"
29 changes: 19 additions & 10 deletions capsula/_context/_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
import warnings
from pathlib import Path
from shutil import copyfile, move
from typing import TYPE_CHECKING, Callable, Iterable
from typing import TYPE_CHECKING, Callable, Iterable, TypedDict

from capsula._backport import file_digest

from ._base import ContextBase

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams
from capsula._run import CapsuleParams

logger = logging.getLogger(__name__)


class _FileContextData(TypedDict):
copied_to: tuple[Path, ...]
moved_to: Path | None
hash: dict[str, str] | None


class FileContext(ContextBase):
_default_hash_algorithm = "sha256"

Expand Down Expand Up @@ -46,26 +52,29 @@ def _normalize_copy_dst_path(self, p: Path) -> Path:
else:
return p

def encapsulate(self) -> dict:
def encapsulate(self) -> _FileContextData:
self.copy_to = tuple(self._normalize_copy_dst_path(p) for p in self.copy_to)

info: dict = {
"copied_to": self.copy_to,
"moved_to": self.move_to,
}

if self.compute_hash:
with self.path.open("rb") as f:
digest = file_digest(f, self.hash_algorithm).hexdigest()
info["hash"] = {
hash_data = {
"algorithm": self.hash_algorithm,
"digest": digest,
}
else:
hash_data = None

info: _FileContextData = {
"copied_to": self.copy_to,
"moved_to": self.move_to,
"hash": hash_data,
}

for path in self.copy_to:
copyfile(self.path, path)
if self.move_to is not None:
move(self.path, self.move_to)
move(str(self.path), self.move_to)

return info

Expand Down
13 changes: 10 additions & 3 deletions capsula/_context/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@

import inspect
from pathlib import Path
from typing import Any, Callable, Mapping, Sequence
from typing import Any, Callable, Mapping, Sequence, TypedDict

from ._base import ContextBase


class _FunctionCallContextData(TypedDict):
file_path: Path
first_line_no: int
args: Sequence[Any]
kwargs: Mapping[str, Any]


class FunctionCallContext(ContextBase):
def __init__(self, function: Callable, args: Sequence[Any], kwargs: Mapping[str, Any]) -> None:
def __init__(self, function: Callable[..., Any], args: Sequence[Any], kwargs: Mapping[str, Any]) -> None:
self.function = function
self.args = args
self.kwargs = kwargs

def encapsulate(self) -> dict:
def encapsulate(self) -> _FunctionCallContextData:
file_path = Path(inspect.getfile(self.function))
_, first_line_no = inspect.getsourcelines(self.function)
return {
Expand Down
18 changes: 14 additions & 4 deletions capsula/_context/_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, TypedDict

from git.repo import Repo

Expand All @@ -12,7 +12,9 @@
from ._base import ContextBase

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams
from os import PathLike

from capsula._run import CapsuleParams

logger = logging.getLogger(__name__)

Expand All @@ -23,6 +25,14 @@ def __init__(self, repo: Repo) -> None:
super().__init__(f"Repository {repo.working_dir} is dirty")


class _GitRepositoryContextData(TypedDict):
working_dir: PathLike[str] | str
sha: str
remotes: dict[str, str]
branch: str
is_dirty: bool


class GitRepositoryContext(ContextBase):
def __init__(
self,
Expand All @@ -39,12 +49,12 @@ def __init__(
self.allow_dirty = allow_dirty
self.diff_file = None if diff_file is None else Path(diff_file)

def encapsulate(self) -> dict:
def encapsulate(self) -> _GitRepositoryContextData:
repo = Repo(self.path, search_parent_directories=self.search_parent_directories)
if not self.allow_dirty and repo.is_dirty():
raise GitRepositoryDirtyError(repo)

info = {
info: _GitRepositoryContextData = {
"working_dir": repo.working_dir,
"sha": repo.head.commit.hexsha,
"remotes": {remote.name: remote.url for remote in repo.remotes},
Expand Down
16 changes: 15 additions & 1 deletion capsula/_context/_platform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from __future__ import annotations

import platform as pf
from typing import TypedDict

from ._base import ContextBase


class _PlatformContextData(TypedDict):
machine: str
node: str
platform: str
release: str
version: str
system: str
processor: str
python: dict[str, str | dict[str, str]]


class PlatformContext(ContextBase):
def encapsulate(self) -> dict:
def encapsulate(self) -> _PlatformContextData:
return {
"machine": pf.machine(),
"node": pf.node(),
Expand Down
2 changes: 1 addition & 1 deletion capsula/_reporter/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod

from capsula.encapsulator import Capsule
from capsula._capsule import Capsule


class ReporterBase(ABC):
Expand Down
4 changes: 2 additions & 2 deletions capsula/_reporter/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from capsula.utils import to_nested_dict

if TYPE_CHECKING:
from capsula._decorator import CapsuleParams
from capsula.encapsulator import Capsule
from capsula._capsule import Capsule
from capsula._run import CapsuleParams

from ._base import ReporterBase

Expand Down
2 changes: 1 addition & 1 deletion capsula/_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def record(key: _CapsuleItemKey, value: Any) -> None:


def current_run_name() -> str:
run: Run | None = Run.get_current()
run: Run[Any, Any] | None = Run.get_current()
if run is None:
msg = "No active run found."
raise RuntimeError(msg)
Expand Down
17 changes: 11 additions & 6 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import queue
import threading
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Generic, Literal, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Literal, Tuple, TypeVar, overload

from pydantic import BaseModel

Expand All @@ -24,9 +24,9 @@


class FuncInfo(BaseModel):
func: Callable
args: tuple
kwargs: dict
func: Callable[..., Any]
args: Tuple[Any, ...]
kwargs: Dict[str, Any]


class CapsuleParams(FuncInfo):
Expand All @@ -41,7 +41,7 @@ class Run(Generic[_P, _T]):
def _get_run_stack(cls) -> queue.LifoQueue[Self]:
if not hasattr(cls._thread_local, "run_stack"):
cls._thread_local.run_stack = queue.LifoQueue()
return cls._thread_local.run_stack
return cls._thread_local.run_stack # type: ignore[no-any-return]

@classmethod
def get_current(cls) -> Self | None:
Expand All @@ -58,7 +58,12 @@ def __init__(self, func: Callable[_P, _T], *, pass_pre_run_capsule: Literal[Fals
def __init__(self, func: Callable[Concatenate[Capsule, _P], _T], *, pass_pre_run_capsule: Literal[True]) -> None:
...

def __init__(self, func, *, pass_pre_run_capsule: bool = False) -> None:
def __init__(
self,
func: Callable[_P, _T] | Callable[Concatenate[Capsule, _P], _T],
*,
pass_pre_run_capsule: bool = False,
) -> None:
self.pre_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
self.in_run_watcher_generators: list[Callable[[CapsuleParams], WatcherBase]] = []
self.post_run_context_generators: list[Callable[[CapsuleParams], ContextBase]] = []
Expand Down
10 changes: 5 additions & 5 deletions capsula/encapsulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import threading
from collections import OrderedDict
from collections.abc import Hashable
from contextlib import AbstractContextManager
from itertools import chain
from typing import TYPE_CHECKING, Any, Generic, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Generic, Tuple, TypeVar, Union

from capsula.utils import ExceptionInfo

from ._backport import AbstractContextManager
from ._capsule import Capsule
from ._context import ContextBase
from ._watcher import WatcherBase
Expand Down Expand Up @@ -40,10 +40,10 @@ def encapsulate(self) -> Any:
_V = TypeVar("_V", bound=WatcherBase)


class WatcherGroup(AbstractContextManager, Generic[_K, _V]):
class WatcherGroup(Generic[_K, _V], AbstractContextManager[Dict[_K, Any]]):
def __init__(self, watchers: OrderedDict[_K, _V]) -> None:
self.watchers = watchers
self.context_manager_stack: queue.LifoQueue[AbstractContextManager] = queue.LifoQueue()
self.context_manager_stack: queue.LifoQueue[AbstractContextManager[None]] = queue.LifoQueue()

def __enter__(self) -> dict[_K, Any]:
self.context_manager_stack = queue.LifoQueue()
Expand Down Expand Up @@ -83,7 +83,7 @@ class Encapsulator:
def _get_context_stack(cls) -> queue.LifoQueue[Self]:
if not hasattr(cls._thread_local, "context_stack"):
cls._thread_local.context_stack = queue.LifoQueue()
return cls._thread_local.context_stack
return cls._thread_local.context_stack # type: ignore[no-any-return]

@classmethod
def get_current(cls) -> Self | None:
Expand Down
2 changes: 1 addition & 1 deletion coverage/badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 1735303

Please sign in to comment.