Skip to content

Commit

Permalink
Add type information and require types
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen committed Mar 6, 2024
1 parent 7675a85 commit 70657af
Show file tree
Hide file tree
Showing 14 changed files with 62 additions and 46 deletions.
4 changes: 2 additions & 2 deletions client_utils/cli/download_lcp_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ async def process_command(
manifest_file: Path | str | int,
username: str,
password: str | None,
manifest_member_name="manifest.json",
pretty_print=False,
manifest_member_name: str = "manifest.json",
pretty_print: bool = False,
) -> None:
client_headers = {"User-Agent": "Palace"}
token: BaseAuthorizationToken = BasicAuthToken.from_username_and_password(
Expand Down
4 changes: 2 additions & 2 deletions client_utils/cli/palace_terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def init_handlers(self) -> None:
)


class MediaPlayerUi(App):
class MediaPlayerUi(App[None]):
TITLE = "PALACE - Terminal Edition"
BINDINGS = [
("p", "play", "Play Media"),
Expand Down Expand Up @@ -601,7 +601,7 @@ def update_player_info(self) -> None:
log(traceback.format_exc())

@staticmethod
def set_playing_row(table: DataTable, key: str) -> None:
def set_playing_row(table: DataTable[str], key: str) -> None:
for row_key in table.rows.keys():
if row_key.value == key:
updated_str = "▶️"
Expand Down
7 changes: 4 additions & 3 deletions client_utils/cli/patron_bookshelf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typer

from client_utils.constants import DEFAULT_REGISTRY_URL
from client_utils.models.api.opds2 import OPDS2Feed
from client_utils.models.internal.bookshelf import print_bookshelf_summary
from client_utils.roles.patron import authenticate
from client_utils.utils.http.async_client import HTTPXAsyncClient
Expand All @@ -14,7 +15,7 @@
app = typer.Typer(rich_markup_mode="rich")


def main():
def main() -> None:
run_typer_app_as_main(app)


Expand Down Expand Up @@ -78,7 +79,7 @@ def patron_bookshelf(
help="Output bookshelf as JSON.",
rich_help_panel="Output",
),
):
) -> None:
bookshelf = asyncio.run(
fetch_bookshelf(
username=username,
Expand All @@ -105,7 +106,7 @@ async def fetch_bookshelf(
opds_server: str | None = None,
auth_doc_url: str | None = None,
allow_hidden_libraries: bool = False,
):
) -> OPDS2Feed:
async with HTTPXAsyncClient() as client:
patron = await authenticate(
username=username,
Expand Down
2 changes: 1 addition & 1 deletion client_utils/cli/summarize_rwpm_audio_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def format_delta(delta: int | float, delta_suffix: str | None = None) -> str:
def text_with_time_delta(
text: str,
delta_secs: int | float,
delta_label="duration",
delta_label: str = "duration",
delta_suffix: str | None = None,
second_delta: int | float | None = None,
second_delta_suffix: str | None = None,
Expand Down
3 changes: 2 additions & 1 deletion client_utils/models/api/authentication_document.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any

from client_utils.constants import PATRON_BOOKSHELF_REL, PATRON_PROFILE_REL
from client_utils.models.api.opds2 import OPDS2Link, match_links
Expand Down Expand Up @@ -61,7 +62,7 @@ class AuthenticationDocument(ApiBaseModel):
authentication: list[AuthenticationMechanism]
features: Features
links: list[OPDS2Link]
announcements: list
announcements: list[Any]
service_description: str
public_key: PublicKey
# color_scheme: str
Expand Down
16 changes: 8 additions & 8 deletions client_utils/models/api/opds2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,26 @@ class OPDS2Link(ApiBaseModel):
templated: bool = False

@property
def is_acquisition(self):
def is_acquisition(self) -> bool:
"""Is this an acquisition link?"""
return any(
rel in [OPDS_ACQ_STANDARD_REL, OPDS_ACQ_OPEN_ACCESS_REL]
for rel in ensure_list(self.rel)
)

@property
def has_indirect_acquisition(self):
def has_indirect_acquisition(self) -> bool:
"""Does link have one or more indirect acquisition links?"""
return bool(self.indirect_acquisition_links)

@property
def indirect_acquisition_links(self) -> Sequence[Mapping[str, Any]]:
"""Indirect acquisition link, if any."""
return vars(self).get("properties", {}).get("indirectAcquisition", [])
return vars(self).get("properties", {}).get("indirectAcquisition", []) # type: ignore[no-any-return]


class OPDS2(ApiBaseModel):
catalogs: list = []
catalogs: list[Any] = []
links: list[OPDS2Link] = []
metadata: Mapping[str, Any]

Expand Down Expand Up @@ -116,14 +116,14 @@ class Publication(ApiBaseModel):
images: list[Image] = []

@property
def acquisition_links(self):
def acquisition_links(self) -> list[OPDS2Link]:
return match_links(
self.links,
lambda link: link.rel in [OPDS_ACQ_STANDARD_REL, OPDS_ACQ_OPEN_ACCESS_REL],
)

@property
def revoke_links(self):
def revoke_links(self) -> list[OPDS2Link]:
return match_links(self.links, lambda link: link.rel == OPDS_REVOKE_REL)

@property
Expand All @@ -138,10 +138,10 @@ class FeedMetadata(ApiBaseModel):

class OPDS2Feed(ApiBaseModel):
publications: list[Publication] = []
catalogs: list = []
catalogs: list[Any] = []
links: list[OPDS2Link] = []
metadata: Mapping[str, Any]
facets: list
facets: list[Any]


L = TypeVar("L", bound=Mapping[str, str] | OPDS2Link)
Expand Down
4 changes: 2 additions & 2 deletions client_utils/models/api/rwpm_audiobook.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def toc_in_playback_order(self) -> Generator[ToCEntry, None, None]:
yield from child.toc_in_playback_order()

@classmethod
def from_track(cls, track: AudioTrack, default_title="Track") -> Self:
def from_track(cls, track: AudioTrack, default_title: str = "Track") -> Self:
"""Create a ToCEntry from an AudioTrack."""
toc_href = f"{track.href}#t=0"
toc_title = track.title or default_title
Expand Down Expand Up @@ -89,7 +89,7 @@ def validate_model(self) -> Self:
return self

@cached_property
def effective_toc(self):
def effective_toc(self) -> Sequence[ToCEntry]:
return self.toc or [
ToCEntry.from_track(track=track, default_title=f"Track {n}")
for n, track in enumerate(self.reading_order, start=1)
Expand Down
2 changes: 1 addition & 1 deletion client_utils/models/internal/rwpm_audio/audio_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AudioSegment:
duration: int = field(init=False)
actual_duration: float = field(init=False)

def __post_init__(self):
def __post_init__(self) -> None:
self.duration = self.end - self.start
self.actual_duration = self.end_actual - self.start

Expand Down
10 changes: 5 additions & 5 deletions client_utils/models/internal/rwpm_audio/audiobook.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class EnhancedToCEntry(ToCEntry):
actual_duration: float = 0.0

@cached_property
def total_duration(self):
def total_duration(self) -> int:
"""The duration (in seconds) of this ToCEntry and its children."""
return sum(toc.duration for toc in self.enhanced_toc_in_playback_order)

Expand Down Expand Up @@ -64,7 +64,7 @@ def generate_enhanced_toc(
self,
toc: Sequence[ToCEntry] | None,
depth: int = 0,
):
) -> Sequence[EnhancedToCEntry]:
"""Recursively generate enhanced ToC entries."""
return (
[
Expand Down Expand Up @@ -124,12 +124,12 @@ def pre_toc_unplayed_audio_segments(self) -> Sequence[AudioSegment]:
).audio_segments

@cached_property
def toc_total_duration(self):
def toc_total_duration(self) -> int:
"""The duration (in seconds) of this ToCEntry and its children."""
return sum(toc.duration for toc in self.enhanced_toc_in_playback_order)

@cached_property
def toc_actual_total_duration(self):
def toc_actual_total_duration(self) -> float:
"""The duration (in seconds) of this ToCEntry and its children."""
return sum(toc.actual_duration for toc in self.enhanced_toc_in_playback_order)

Expand All @@ -141,6 +141,6 @@ def from_manifest_file(cls, filepath: Path | str) -> Self:
# Try to load the track
track_file = directory_path / track.href
if track_file.is_file():
track.actual_duration = MP3(track_file).info.length # type: ignore[attr-defined]
track.actual_duration = MP3(track_file).info.length

return cls(manifest=manifest)
10 changes: 7 additions & 3 deletions client_utils/roles/patron.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class PatronAuthorization:
class AuthenticatedPatron(PatronAuthorization):
authentication_document: AuthenticationDocument

async def patron_profile_document(self, http_client=None):
async def patron_profile_document(
self, http_client: HTTPXAsyncClient | None = None
) -> PatronProfileDocument:
[patron_profile_link] = self.authentication_document.patron_profile_links
headers = dict(self.token.as_http_headers)
async with HTTPXAsyncClient.with_existing_client(
Expand All @@ -52,7 +54,9 @@ async def patron_profile_document(self, http_client=None):
).json()
return PatronProfileDocument.model_validate(profile)

async def patron_bookshelf(self, http_client=None):
async def patron_bookshelf(
self, http_client: HTTPXAsyncClient | None = None
) -> OPDS2Feed:
[patron_bookshelf_link] = self.authentication_document.patron_bookshelf_links
headers = dict(self.token.as_http_headers) | {"Accept": OPDS_2_TYPE}
async with HTTPXAsyncClient.with_existing_client(
Expand All @@ -74,7 +78,7 @@ async def authenticate(
opds_server: str | None = None,
allow_hidden_libraries: bool = False,
http_client: HTTPXAsyncClient | None = None,
):
) -> AuthenticatedPatron:
"""Login as a patron."""
async with HTTPXAsyncClient.with_existing_client(
existing_client=http_client
Expand Down
17 changes: 10 additions & 7 deletions client_utils/utils/http/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sys
from contextlib import nullcontext
from typing import Any

from httpx import AsyncClient, Response

Expand All @@ -14,22 +15,24 @@


class HTTPXAsyncClient(AsyncClient):
def __init__(self, user_agent: str = DEFAULT_USER_AGENT, **kwargs):
def __init__(self, user_agent: str = DEFAULT_USER_AGENT, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.user_agent = user_agent

async def request(self, method, url, *, headers=None, **kwargs) -> Response:
async def request(self, method: str, url: str, *, headers: dict[str, str] | None = None, **kwargs: Any) -> Response: # type: ignore[override]
headers = {"User-Agent": self.user_agent} | (headers or {})
return await super().request(method, url, headers=headers, **kwargs)

async def post(self, *args, **kwargs) -> Response:
async def post(self, *args: Any, **kwargs: Any) -> Response:
return await self.request("POST", *args, **kwargs)

async def get(self, *args, **kwargs) -> Response:
async def get(self, *args: Any, **kwargs: Any) -> Response:
return await self.request("GET", *args, **kwargs)

@classmethod
def with_existing_client(cls, *args, existing_client: Self | None = None, **kwargs):
def with_existing_client(
cls, *args: Any, existing_client: Self | None = None, **kwargs: Any
) -> Self:
"""Return an instance of our self.
:param existing_client: A client to use instead of creating a new one.
Expand All @@ -39,12 +42,12 @@ def with_existing_client(cls, *args, existing_client: Self | None = None, **kwar
If not, a new one will be instantiated, using the provided arguments.
"""
if existing_client:
return nullcontext(enter_result=existing_client)
return nullcontext(enter_result=existing_client) # type: ignore[return-value]
else:
return cls(*args, **kwargs)


def validate_response(response: Response, raise_for_status=True):
def validate_response(response: Response, raise_for_status: bool = True) -> Response:
if raise_for_status:
response.raise_for_status()
return response
15 changes: 9 additions & 6 deletions client_utils/utils/http/streaming.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Any, BinaryIO, ContextManager
from typing import Any, BinaryIO, ContextManager, TypeVar

import rich.progress
from httpx import Response
Expand All @@ -18,7 +18,10 @@ def default_progress_bar() -> rich.progress.Progress:
)


def _to_list(value) -> list:
T = TypeVar("T")


def _to_list(value: Sequence[T] | T | None) -> list[T]:
"""Ensure that we end up with our own copy as a list.
:param value: The value from which we'll create our list.
Expand All @@ -45,8 +48,8 @@ async def streaming_fetch(
| Sequence[Callable[[int], Any]]
| None = None,
http_client: HTTPXAsyncClient | None = None,
raise_for_status=False,
):
raise_for_status: bool = False,
) -> Response:
async with HTTPXAsyncClient.with_existing_client(http_client) as client:
async with client.stream("GET", url=url) as response:
if raise_for_status:
Expand Down Expand Up @@ -78,9 +81,9 @@ async def streaming_fetch_with_progress(
total_setters: Callable[[int], Any] | list[Callable[[int], Any]] | None = None,
progress_updaters: Callable[[int], Any] | list[Callable[[int], Any]] | None = None,
http_client: HTTPXAsyncClient | None = None,
raise_for_status=False,
raise_for_status: bool = False,
) -> Response:
_progress_bar: ContextManager | None = None
_progress_bar: ContextManager[rich.progress.Progress] | None = None
_task_label: str | None = None
if isinstance(progress_bar, rich.progress.Progress):
_progress_bar = progress_bar
Expand Down
2 changes: 1 addition & 1 deletion client_utils/utils/typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import typer


def run_typer_app_as_main(app, *args, **kwargs) -> Any | None:
def run_typer_app_as_main(app: typer.Typer, *args: Any, **kwargs: Any) -> Any | None:
"""Run a typer app as the main function.
Catch any uncaught exceptions and print them to stderr.
Expand Down
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@ profile = "black"
[tool.mypy]
files = ["."]
plugins = ["pydantic.mypy"]
warn_redundant_casts = true
warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
strict = true

[[tool.mypy.overrides]]
disallow_untyped_defs = false
module = [
"client_utils.utils.*",
"client_utils.utils.http.streaming",
]

[[tool.mypy.overrides]]
ignore_missing_imports = true
Expand Down

0 comments on commit 70657af

Please sign in to comment.