From 70657af561abc2f6dc60998fdedded4833c23177 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 6 Mar 2024 15:22:16 -0400 Subject: [PATCH] Add type information and require types --- client_utils/cli/download_lcp_manifest.py | 4 ++-- client_utils/cli/palace_terminal.py | 4 ++-- client_utils/cli/patron_bookshelf.py | 7 ++++--- .../cli/summarize_rwpm_audio_manifest.py | 2 +- .../models/api/authentication_document.py | 3 ++- client_utils/models/api/opds2.py | 16 ++++++++-------- client_utils/models/api/rwpm_audiobook.py | 4 ++-- .../models/internal/rwpm_audio/audio_segment.py | 2 +- .../models/internal/rwpm_audio/audiobook.py | 10 +++++----- client_utils/roles/patron.py | 10 +++++++--- client_utils/utils/http/async_client.py | 17 ++++++++++------- client_utils/utils/http/streaming.py | 15 +++++++++------ client_utils/utils/typer.py | 2 +- pyproject.toml | 12 ++++++++---- 14 files changed, 62 insertions(+), 46 deletions(-) diff --git a/client_utils/cli/download_lcp_manifest.py b/client_utils/cli/download_lcp_manifest.py index 47f9ed7..5c5aa35 100755 --- a/client_utils/cli/download_lcp_manifest.py +++ b/client_utils/cli/download_lcp_manifest.py @@ -66,8 +66,8 @@ async def process_command( manifest_file: Path | str | int, username: str, password: str | None, - manifest_member_name="manifest.json", - pretty_print=False, + manifest_member_name: str = "manifest.json", + pretty_print: bool = False, ) -> None: client_headers = {"User-Agent": "Palace"} token: BaseAuthorizationToken = BasicAuthToken.from_username_and_password( diff --git a/client_utils/cli/palace_terminal.py b/client_utils/cli/palace_terminal.py index 0a4f00d..a41d9ee 100644 --- a/client_utils/cli/palace_terminal.py +++ b/client_utils/cli/palace_terminal.py @@ -431,7 +431,7 @@ def init_handlers(self) -> None: ) -class MediaPlayerUi(App): +class MediaPlayerUi(App[None]): TITLE = "PALACE - Terminal Edition" BINDINGS = [ ("p", "play", "Play Media"), @@ -601,7 +601,7 @@ def update_player_info(self) -> None: log(traceback.format_exc()) @staticmethod - def set_playing_row(table: DataTable, key: str) -> None: + def set_playing_row(table: DataTable[str], key: str) -> None: for row_key in table.rows.keys(): if row_key.value == key: updated_str = "▶️" diff --git a/client_utils/cli/patron_bookshelf.py b/client_utils/cli/patron_bookshelf.py index b672fd8..e5fb891 100755 --- a/client_utils/cli/patron_bookshelf.py +++ b/client_utils/cli/patron_bookshelf.py @@ -6,6 +6,7 @@ import typer from client_utils.constants import DEFAULT_REGISTRY_URL +from client_utils.models.api.opds2 import OPDS2Feed from client_utils.models.internal.bookshelf import print_bookshelf_summary from client_utils.roles.patron import authenticate from client_utils.utils.http.async_client import HTTPXAsyncClient @@ -14,7 +15,7 @@ app = typer.Typer(rich_markup_mode="rich") -def main(): +def main() -> None: run_typer_app_as_main(app) @@ -78,7 +79,7 @@ def patron_bookshelf( help="Output bookshelf as JSON.", rich_help_panel="Output", ), -): +) -> None: bookshelf = asyncio.run( fetch_bookshelf( username=username, @@ -105,7 +106,7 @@ async def fetch_bookshelf( opds_server: str | None = None, auth_doc_url: str | None = None, allow_hidden_libraries: bool = False, -): +) -> OPDS2Feed: async with HTTPXAsyncClient() as client: patron = await authenticate( username=username, diff --git a/client_utils/cli/summarize_rwpm_audio_manifest.py b/client_utils/cli/summarize_rwpm_audio_manifest.py index 4001253..4c41027 100755 --- a/client_utils/cli/summarize_rwpm_audio_manifest.py +++ b/client_utils/cli/summarize_rwpm_audio_manifest.py @@ -46,7 +46,7 @@ def format_delta(delta: int | float, delta_suffix: str | None = None) -> str: def text_with_time_delta( text: str, delta_secs: int | float, - delta_label="duration", + delta_label: str = "duration", delta_suffix: str | None = None, second_delta: int | float | None = None, second_delta_suffix: str | None = None, diff --git a/client_utils/models/api/authentication_document.py b/client_utils/models/api/authentication_document.py index fd115bd..dcdce38 100644 --- a/client_utils/models/api/authentication_document.py +++ b/client_utils/models/api/authentication_document.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Mapping +from typing import Any from client_utils.constants import PATRON_BOOKSHELF_REL, PATRON_PROFILE_REL from client_utils.models.api.opds2 import OPDS2Link, match_links @@ -61,7 +62,7 @@ class AuthenticationDocument(ApiBaseModel): authentication: list[AuthenticationMechanism] features: Features links: list[OPDS2Link] - announcements: list + announcements: list[Any] service_description: str public_key: PublicKey # color_scheme: str diff --git a/client_utils/models/api/opds2.py b/client_utils/models/api/opds2.py index 644e7a0..43aabc2 100644 --- a/client_utils/models/api/opds2.py +++ b/client_utils/models/api/opds2.py @@ -33,7 +33,7 @@ class OPDS2Link(ApiBaseModel): templated: bool = False @property - def is_acquisition(self): + def is_acquisition(self) -> bool: """Is this an acquisition link?""" return any( rel in [OPDS_ACQ_STANDARD_REL, OPDS_ACQ_OPEN_ACCESS_REL] @@ -41,18 +41,18 @@ def is_acquisition(self): ) @property - def has_indirect_acquisition(self): + def has_indirect_acquisition(self) -> bool: """Does link have one or more indirect acquisition links?""" return bool(self.indirect_acquisition_links) @property def indirect_acquisition_links(self) -> Sequence[Mapping[str, Any]]: """Indirect acquisition link, if any.""" - return vars(self).get("properties", {}).get("indirectAcquisition", []) + return vars(self).get("properties", {}).get("indirectAcquisition", []) # type: ignore[no-any-return] class OPDS2(ApiBaseModel): - catalogs: list = [] + catalogs: list[Any] = [] links: list[OPDS2Link] = [] metadata: Mapping[str, Any] @@ -116,14 +116,14 @@ class Publication(ApiBaseModel): images: list[Image] = [] @property - def acquisition_links(self): + def acquisition_links(self) -> list[OPDS2Link]: return match_links( self.links, lambda link: link.rel in [OPDS_ACQ_STANDARD_REL, OPDS_ACQ_OPEN_ACCESS_REL], ) @property - def revoke_links(self): + def revoke_links(self) -> list[OPDS2Link]: return match_links(self.links, lambda link: link.rel == OPDS_REVOKE_REL) @property @@ -138,10 +138,10 @@ class FeedMetadata(ApiBaseModel): class OPDS2Feed(ApiBaseModel): publications: list[Publication] = [] - catalogs: list = [] + catalogs: list[Any] = [] links: list[OPDS2Link] = [] metadata: Mapping[str, Any] - facets: list + facets: list[Any] L = TypeVar("L", bound=Mapping[str, str] | OPDS2Link) diff --git a/client_utils/models/api/rwpm_audiobook.py b/client_utils/models/api/rwpm_audiobook.py index ec2fd44..c21045c 100644 --- a/client_utils/models/api/rwpm_audiobook.py +++ b/client_utils/models/api/rwpm_audiobook.py @@ -35,7 +35,7 @@ def toc_in_playback_order(self) -> Generator[ToCEntry, None, None]: yield from child.toc_in_playback_order() @classmethod - def from_track(cls, track: AudioTrack, default_title="Track") -> Self: + def from_track(cls, track: AudioTrack, default_title: str = "Track") -> Self: """Create a ToCEntry from an AudioTrack.""" toc_href = f"{track.href}#t=0" toc_title = track.title or default_title @@ -89,7 +89,7 @@ def validate_model(self) -> Self: return self @cached_property - def effective_toc(self): + def effective_toc(self) -> Sequence[ToCEntry]: return self.toc or [ ToCEntry.from_track(track=track, default_title=f"Track {n}") for n, track in enumerate(self.reading_order, start=1) diff --git a/client_utils/models/internal/rwpm_audio/audio_segment.py b/client_utils/models/internal/rwpm_audio/audio_segment.py index 269e1d7..b39fc59 100644 --- a/client_utils/models/internal/rwpm_audio/audio_segment.py +++ b/client_utils/models/internal/rwpm_audio/audio_segment.py @@ -19,7 +19,7 @@ class AudioSegment: duration: int = field(init=False) actual_duration: float = field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.duration = self.end - self.start self.actual_duration = self.end_actual - self.start diff --git a/client_utils/models/internal/rwpm_audio/audiobook.py b/client_utils/models/internal/rwpm_audio/audiobook.py index b88ee3e..5d8bdc8 100644 --- a/client_utils/models/internal/rwpm_audio/audiobook.py +++ b/client_utils/models/internal/rwpm_audio/audiobook.py @@ -30,7 +30,7 @@ class EnhancedToCEntry(ToCEntry): actual_duration: float = 0.0 @cached_property - def total_duration(self): + def total_duration(self) -> int: """The duration (in seconds) of this ToCEntry and its children.""" return sum(toc.duration for toc in self.enhanced_toc_in_playback_order) @@ -64,7 +64,7 @@ def generate_enhanced_toc( self, toc: Sequence[ToCEntry] | None, depth: int = 0, - ): + ) -> Sequence[EnhancedToCEntry]: """Recursively generate enhanced ToC entries.""" return ( [ @@ -124,12 +124,12 @@ def pre_toc_unplayed_audio_segments(self) -> Sequence[AudioSegment]: ).audio_segments @cached_property - def toc_total_duration(self): + def toc_total_duration(self) -> int: """The duration (in seconds) of this ToCEntry and its children.""" return sum(toc.duration for toc in self.enhanced_toc_in_playback_order) @cached_property - def toc_actual_total_duration(self): + def toc_actual_total_duration(self) -> float: """The duration (in seconds) of this ToCEntry and its children.""" return sum(toc.actual_duration for toc in self.enhanced_toc_in_playback_order) @@ -141,6 +141,6 @@ def from_manifest_file(cls, filepath: Path | str) -> Self: # Try to load the track track_file = directory_path / track.href if track_file.is_file(): - track.actual_duration = MP3(track_file).info.length # type: ignore[attr-defined] + track.actual_duration = MP3(track_file).info.length return cls(manifest=manifest) diff --git a/client_utils/roles/patron.py b/client_utils/roles/patron.py index cf28f11..9ded046 100644 --- a/client_utils/roles/patron.py +++ b/client_utils/roles/patron.py @@ -41,7 +41,9 @@ class PatronAuthorization: class AuthenticatedPatron(PatronAuthorization): authentication_document: AuthenticationDocument - async def patron_profile_document(self, http_client=None): + async def patron_profile_document( + self, http_client: HTTPXAsyncClient | None = None + ) -> PatronProfileDocument: [patron_profile_link] = self.authentication_document.patron_profile_links headers = dict(self.token.as_http_headers) async with HTTPXAsyncClient.with_existing_client( @@ -52,7 +54,9 @@ async def patron_profile_document(self, http_client=None): ).json() return PatronProfileDocument.model_validate(profile) - async def patron_bookshelf(self, http_client=None): + async def patron_bookshelf( + self, http_client: HTTPXAsyncClient | None = None + ) -> OPDS2Feed: [patron_bookshelf_link] = self.authentication_document.patron_bookshelf_links headers = dict(self.token.as_http_headers) | {"Accept": OPDS_2_TYPE} async with HTTPXAsyncClient.with_existing_client( @@ -74,7 +78,7 @@ async def authenticate( opds_server: str | None = None, allow_hidden_libraries: bool = False, http_client: HTTPXAsyncClient | None = None, -): +) -> AuthenticatedPatron: """Login as a patron.""" async with HTTPXAsyncClient.with_existing_client( existing_client=http_client diff --git a/client_utils/utils/http/async_client.py b/client_utils/utils/http/async_client.py index 11922aa..305caf3 100644 --- a/client_utils/utils/http/async_client.py +++ b/client_utils/utils/http/async_client.py @@ -2,6 +2,7 @@ import sys from contextlib import nullcontext +from typing import Any from httpx import AsyncClient, Response @@ -14,22 +15,24 @@ class HTTPXAsyncClient(AsyncClient): - def __init__(self, user_agent: str = DEFAULT_USER_AGENT, **kwargs): + def __init__(self, user_agent: str = DEFAULT_USER_AGENT, **kwargs: Any) -> None: super().__init__(**kwargs) self.user_agent = user_agent - async def request(self, method, url, *, headers=None, **kwargs) -> Response: + async def request(self, method: str, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any) -> Response: # type: ignore[override] headers = {"User-Agent": self.user_agent} | (headers or {}) return await super().request(method, url, headers=headers, **kwargs) - async def post(self, *args, **kwargs) -> Response: + async def post(self, *args: Any, **kwargs: Any) -> Response: return await self.request("POST", *args, **kwargs) - async def get(self, *args, **kwargs) -> Response: + async def get(self, *args: Any, **kwargs: Any) -> Response: return await self.request("GET", *args, **kwargs) @classmethod - def with_existing_client(cls, *args, existing_client: Self | None = None, **kwargs): + def with_existing_client( + cls, *args: Any, existing_client: Self | None = None, **kwargs: Any + ) -> Self: """Return an instance of our self. :param existing_client: A client to use instead of creating a new one. @@ -39,12 +42,12 @@ def with_existing_client(cls, *args, existing_client: Self | None = None, **kwar If not, a new one will be instantiated, using the provided arguments. """ if existing_client: - return nullcontext(enter_result=existing_client) + return nullcontext(enter_result=existing_client) # type: ignore[return-value] else: return cls(*args, **kwargs) -def validate_response(response: Response, raise_for_status=True): +def validate_response(response: Response, raise_for_status: bool = True) -> Response: if raise_for_status: response.raise_for_status() return response diff --git a/client_utils/utils/http/streaming.py b/client_utils/utils/http/streaming.py index 52bf7dc..403cf72 100644 --- a/client_utils/utils/http/streaming.py +++ b/client_utils/utils/http/streaming.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Any, BinaryIO, ContextManager +from typing import Any, BinaryIO, ContextManager, TypeVar import rich.progress from httpx import Response @@ -18,7 +18,10 @@ def default_progress_bar() -> rich.progress.Progress: ) -def _to_list(value) -> list: +T = TypeVar("T") + + +def _to_list(value: Sequence[T] | T | None) -> list[T]: """Ensure that we end up with our own copy as a list. :param value: The value from which we'll create our list. @@ -45,8 +48,8 @@ async def streaming_fetch( | Sequence[Callable[[int], Any]] | None = None, http_client: HTTPXAsyncClient | None = None, - raise_for_status=False, -): + raise_for_status: bool = False, +) -> Response: async with HTTPXAsyncClient.with_existing_client(http_client) as client: async with client.stream("GET", url=url) as response: if raise_for_status: @@ -78,9 +81,9 @@ async def streaming_fetch_with_progress( total_setters: Callable[[int], Any] | list[Callable[[int], Any]] | None = None, progress_updaters: Callable[[int], Any] | list[Callable[[int], Any]] | None = None, http_client: HTTPXAsyncClient | None = None, - raise_for_status=False, + raise_for_status: bool = False, ) -> Response: - _progress_bar: ContextManager | None = None + _progress_bar: ContextManager[rich.progress.Progress] | None = None _task_label: str | None = None if isinstance(progress_bar, rich.progress.Progress): _progress_bar = progress_bar diff --git a/client_utils/utils/typer.py b/client_utils/utils/typer.py index 18fe920..ff2b0dc 100644 --- a/client_utils/utils/typer.py +++ b/client_utils/utils/typer.py @@ -8,7 +8,7 @@ import typer -def run_typer_app_as_main(app, *args, **kwargs) -> Any | None: +def run_typer_app_as_main(app: typer.Typer, *args: Any, **kwargs: Any) -> Any | None: """Run a typer app as the main function. Catch any uncaught exceptions and print them to stderr. diff --git a/pyproject.toml b/pyproject.toml index 97e99b4..7376e73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,14 @@ profile = "black" [tool.mypy] files = ["."] plugins = ["pydantic.mypy"] -warn_redundant_casts = true -warn_unreachable = true -warn_unused_configs = true -warn_unused_ignores = true +strict = true + +[[tool.mypy.overrides]] +disallow_untyped_defs = false +module = [ + "client_utils.utils.*", + "client_utils.utils.http.streaming", +] [[tool.mypy.overrides]] ignore_missing_imports = true