From 626be8a0efca20153b88fa88d93acc4818ad9fac Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 1 Aug 2024 13:19:39 -0300 Subject: [PATCH] Add option to use oauth authentication to OPDS downloader (#73) * Add option to OPDS downloader to harvest with oauth token auth. * Add a comment * Linter fixes --- src/palace_tools/cli/download_feed.py | 5 +- src/palace_tools/feeds/opds.py | 110 +++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 4 deletions(-) diff --git a/src/palace_tools/cli/download_feed.py b/src/palace_tools/cli/download_feed.py index 4cdc185..62d2017 100644 --- a/src/palace_tools/cli/download_feed.py +++ b/src/palace_tools/cli/download_feed.py @@ -91,13 +91,16 @@ def download_overdrive( def download_opds( username: str = typer.Option(None, "--username", "-u", help="Username"), password: str = typer.Option(None, "--password", "-p", help="Password"), + authentication: opds.AuthType = typer.Option( + opds.AuthType.NONE, "--auth", "-a", help="Authentication type" + ), url: str = typer.Argument(..., help="URL of feed", metavar="URL"), output_file: Path = typer.Argument( ..., help="Output file", writable=True, file_okay=True, dir_okay=False ), ) -> None: """Download OPDS 2 feed.""" - publications = opds.fetch(url, username, password) + publications = opds.fetch(url, username, password, authentication) with output_file.open("w") as file: write_json(file, publications) diff --git a/src/palace_tools/feeds/opds.py b/src/palace_tools/feeds/opds.py index 7d973f6..e80b96d 100644 --- a/src/palace_tools/feeds/opds.py +++ b/src/palace_tools/feeds/opds.py @@ -3,12 +3,103 @@ import json import math import sys +from base64 import b64encode +from collections.abc import Generator, Mapping +from enum import Enum from typing import Any, TextIO import httpx from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn +class AuthType(Enum): + BASIC = "basic" + OAUTH = "oauth" + NONE = "none" + + +class OAuthAuth(httpx.Auth): + # Implementation of OPDS auth document OAuth client credentials flow for httpx + # See: + # - https://www.python-httpx.org/advanced/authentication/#custom-authentication-schemes + # - https://drafts.opds.io/authentication-for-opds-1.0.html + + requires_response_body = True + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + self.token: str | None = None + + @staticmethod + def _get_oauth_url_from_auth_document( + url: str, auth_document: Mapping[str, Any] + ) -> str: + auth_types: list[dict[str, Any]] = auth_document.get("authentication", []) + oauth_authentication = [ + tlinks + for t in auth_types + if t.get("type") == "http://opds-spec.org/auth/oauth/client_credentials" + and (tlinks := t.get("links")) is not None + ] + if not oauth_authentication: + print(f"Unable to find supported authentication type ({url})") + print(f"Auth document: {json.dumps(auth_document)}") + sys.exit(-1) + + links = oauth_authentication[0] + auth_links: list[str] = [ + lhref + for l in links + if l.get("rel") == "authenticate" and (lhref := l.get("href")) is not None + ] + if len(auth_links) != 1: + print(f"Unable to find valid authentication link ({url})") + print( + f"Found {len(auth_links)} authentication links. Auth document: {json.dumps(auth_document)}" + ) + sys.exit(-1) + return auth_links[0] + + @staticmethod + def _oauth_token_request(url: str, username: str, password: str) -> httpx.Request: + userpass = ":".join((username, password)) + token = b64encode(userpass.encode()).decode() + headers = {"Authorization": f"Basic {token}"} + return httpx.Request( + "POST", url, headers=headers, data={"grant_type": "client_credentials"} + ) + + def auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + if self.token is not None: + request.headers["Authorization"] = f"Bearer {self.token}" + response = yield request + if ( + response.status_code == 401 + and response.headers.get("Content-Type") + == "application/vnd.opds.authentication.v1.0+json" + ): + oauth_url = self._get_oauth_url_from_auth_document( + str(request.url), response.json() + ) + response = yield self._oauth_token_request( + oauth_url, self.username, self.password + ) + if response.status_code != 200: + print(f"Error: {response.status_code}") + print(response.text) + sys.exit(-1) + if (access_token := response.json().get("access_token")) is None: + print("No access token in response") + print(response.text) + sys.exit(-1) + self.token = access_token + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + def make_request(session: httpx.Client, url: str) -> dict[str, Any]: response = session.get(url) if response.status_code != 200: @@ -23,15 +114,28 @@ def write_json(file: TextIO, data: list[dict[str, Any]]) -> None: file.write(json.dumps(data, indent=4)) -def fetch(url: str, username: str | None, password: str | None) -> list[dict[str, Any]]: +def fetch( + url: str, username: str | None, password: str | None, auth_type: AuthType +) -> list[dict[str, Any]]: # Create a session to fetch the documents client = httpx.Client() - client.headers.update({"Accept": "application/opds+json", "User-Agent": "Palace"}) + client.headers.update( + { + "Accept": "application/opds+json, application/json;q=0.9, */*;q=0.1", + "User-Agent": "Palace", + } + ) client.timeout = httpx.Timeout(30.0) if username and password: - client.auth = httpx.BasicAuth(username, password) + if auth_type == AuthType.BASIC: + client.auth = httpx.BasicAuth(username, password) + elif auth_type == AuthType.OAUTH: + client.auth = OAuthAuth(username, password) + elif auth_type != AuthType.NONE: + print("Username and password are required for authentication") + sys.exit(-1) publications = []