Skip to content

Commit

Permalink
Refactor BadResponseException (PP-1641) (#2034)
Browse files Browse the repository at this point in the history
* Refactor BadResponseException

Refactor the exception, so it has a reference to the response that
can be used to deal with the exception.

* Add a couple more test

* Code review feedback
  • Loading branch information
jonathangreen authored Sep 6, 2024
1 parent c7489f7 commit f8440a8
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 153 deletions.
21 changes: 12 additions & 9 deletions src/palace/manager/api/odl/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from palace.manager.sqlalchemy.util import get_one
from palace.manager.util import base64
from palace.manager.util.datetime_helpers import utc_now
from palace.manager.util.http import BadResponseException
from palace.manager.util.http import BadResponseException, RemoteIntegrationException


class OPDS2WithODLApi(
Expand Down Expand Up @@ -246,8 +246,10 @@ def get_license_status_document(self, loan: Loan) -> dict[str, Any]:
hint_url=self.settings.passphrase_hint_url,
)

response = self._get(url)
if not (200 <= response.status_code < 300):
try:
response = self._get(url, allowed_response_codes=["2xx"])
except BadResponseException as e:
response = e.response
header_string = ", ".join(
{f"{k}: {v}" for k, v in response.headers.items()}
)
Expand All @@ -261,16 +263,17 @@ def get_license_status_document(self, loan: Loan) -> dict[str, Any]:
f"status code {response.status_code}. Expected 2XX. Response headers: {header_string}. "
f"Response content: {response_string}."
)
raise BadResponseException(url, "License Status Document request failed.")

raise RemoteIntegrationException(
url, "License Status Document request failed."
) from e
try:
status_doc = json.loads(response.content)
except ValueError as e:
raise BadResponseException(
raise RemoteIntegrationException(
url, "License Status Document was not valid JSON."
)
) from e
if status_doc.get("status") not in self.STATUS_VALUES:
raise BadResponseException(
raise RemoteIntegrationException(
url, "License Status Document had an unknown status value."
)
return status_doc # type: ignore[no-any-return]
Expand Down Expand Up @@ -958,7 +961,7 @@ def update_loan(self, loan: Loan, status_doc: dict[str, Any] | None = None) -> N
# We already check that the status is valid in get_license_status_document,
# but if the document came from a notification it hasn't been checked yet.
if status not in self.STATUS_VALUES:
raise BadResponseException(
raise RemoteIntegrationException(
str(loan.license.checkout_url),
"The License Status Document had an unknown status value.",
)
Expand Down
4 changes: 2 additions & 2 deletions src/palace/manager/api/overdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,10 @@ def get(
if status_code == 401:
if exception_on_401:
# This is our second try. Give up.
raise BadResponseException.from_response(
raise BadResponseException(
url,
"Something's wrong with the Overdrive OAuth Bearer Token!",
(status_code, headers, content),
response,
)
else:
# Refresh the token and try again.
Expand Down
9 changes: 4 additions & 5 deletions src/palace/manager/core/opds2_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import webpub_manifest_parser.opds2.ast as opds2_ast
from flask_babel import lazy_gettext as _
from requests import Response
from sqlalchemy.orm import Session
from uritemplate import URITemplate
from webpub_manifest_parser.core import ManifestParserFactory, ManifestParserResult
Expand Down Expand Up @@ -1157,18 +1158,16 @@ class OPDS2ImportMonitor(OPDSImportMonitor):
PROTOCOL = OPDS2API.label()
MEDIA_TYPE = OPDS2MediaTypesRegistry.OPDS_FEED.key, "application/json"

def _verify_media_type(
self, url: str, status_code: int, headers: Mapping[str, str], feed: bytes
) -> None:
def _verify_media_type(self, url: str, resp: Response) -> None:
# Make sure we got an OPDS feed, and not an error page that was
# sent with a 200 status code.
media_type = headers.get("content-type")
media_type = resp.headers.get("content-type")
if not media_type or not any(x in media_type for x in self.MEDIA_TYPE):
message = "Expected {} OPDS 2.0 feed, got {}".format(
self.MEDIA_TYPE, media_type
)

raise BadResponseException(url, message=message, status_code=status_code)
raise BadResponseException(url, message=message, response=resp)

def _get_accept_header(self) -> str:
return "{}, {};q=0.9, */*;q=0.1".format(
Expand Down
12 changes: 4 additions & 8 deletions src/palace/manager/core/opds_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,17 +1855,15 @@ def identifier_needs_import(
return True
return False

def _verify_media_type(
self, url: str, status_code: int, headers: Mapping[str, str], feed: bytes
) -> None:
def _verify_media_type(self, url: str, resp: Response) -> None:
# Make sure we got an OPDS feed, and not an error page that was
# sent with a 200 status code.
media_type = headers.get("content-type")
media_type = resp.headers.get("content-type")
if not media_type or not any(
x in media_type for x in (OPDSFeed.ATOM_LIKE_TYPES)
):
message = "Expected Atom feed, got %s" % media_type
raise BadResponseException(url, message=message, status_code=status_code)
raise BadResponseException(url, message=message, response=resp)

def follow_one_link(
self, url: str, do_get: Callable[..., Response] | None = None
Expand All @@ -1881,10 +1879,8 @@ def follow_one_link(
get = do_get or self._get
resp = get(url, {})
feed = resp.content
status_code = resp.status_code
headers = resp.headers

self._verify_media_type(url, status_code, headers, feed)
self._verify_media_type(url, resp)

new_data = self.feed_contains_new_data(feed)

Expand Down
51 changes: 10 additions & 41 deletions src/palace/manager/util/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from collections.abc import Callable, Mapping, Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import Any
from urllib.parse import urlparse

import requests
Expand All @@ -28,9 +28,6 @@
ProblemDetailException,
)

if TYPE_CHECKING:
from palace.manager.sqlalchemy.model.resource import HttpResponseTuple


class RemoteIntegrationException(IntegrationException, BaseProblemDetailException):
"""An exception that happens when we try and fail to communicate
Expand Down Expand Up @@ -99,55 +96,27 @@ def __init__(
self,
url_or_service: str,
message: str,
response: Response,
debug_message: str | None = None,
status_code: int | None = None,
):
"""Indicate that a remote integration has failed.
`param url_or_service` The name of the service that failed
(e.g. "Overdrive"), or the specific URL that had the problem.
"""
super().__init__(url_or_service, message, debug_message)
# to be set to 500, etc.
self.status_code = status_code

@classmethod
def from_response(
cls, url: str, message: str, response: HttpResponseTuple | Response
) -> Self:
"""Helper method to turn a `requests` Response object into
a BadResponseException.
"""
if isinstance(response, tuple):
# The response has been unrolled into a (status_code,
# headers, body) 3-tuple.
status_code, _, content_bytes = response
# The HTTP content response is a bytestring that we want to
# convert to unicode for the debug message.
if content_bytes:
content = content_bytes.decode("utf-8")
else:
content = ""
else:
status_code = response.status_code
content = response.text
if debug_message is None:
debug_message = (
f"Status code: {response.status_code}\nContent: {response.text}"
)

return cls(
url,
message,
status_code=status_code,
debug_message="Status code: %s\nContent: %s"
% (
status_code,
content,
),
)
super().__init__(url_or_service, message, debug_message)
self.response = response

@classmethod
def bad_status_code(cls, url: str, response: Response) -> Self:
"""The response is bad because the status code is wrong."""
message = cls.BAD_STATUS_CODE_MESSAGE % response.status_code
return cls.from_response(
return cls(
url,
message,
response,
Expand Down Expand Up @@ -399,9 +368,9 @@ def _process_response(
raise BadResponseException(
url,
error_message % code,
status_code=response.status_code,
debug_message="Response content: %s"
% cls._decode_response_content(response, url),
response=response,
)
return response

Expand Down
40 changes: 11 additions & 29 deletions tests/fixtures/odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from palace.manager.sqlalchemy.model.work import Work
from tests.fixtures.database import DatabaseTransactionFixture
from tests.fixtures.files import FilesFixture, OPDS2WithODLFilesFixture
from tests.mocks.mock import MockRequestsResponse
from tests.mocks.mock import MockHTTPClient, MockRequestsResponse
from tests.mocks.odl import MockOPDS2WithODLApi


Expand All @@ -44,7 +44,8 @@ def __init__(
self.collection = self.create_collection(self.library)
self.work = self.create_work(self.collection)
self.license = self.setup_license()
self.api = MockOPDS2WithODLApi(self.db.session, self.collection)
self.mock_http = MockHTTPClient()
self.api = MockOPDS2WithODLApi(self.db.session, self.collection, self.mock_http)
self.patron = db.patron()
self.pool = self.license.license_pool

Expand Down Expand Up @@ -100,12 +101,6 @@ def checkin(
) -> None:
patron = patron or self.patron
pool = pool or self.pool
self._checkin(self.api, patron=patron, pool=pool)

@staticmethod
def _checkin(api: MockOPDS2WithODLApi, patron: Patron, pool: LicensePool) -> None:
"""Create a function that, when evaluated, performs a checkin."""

lsd = json.dumps(
{
"status": "ready",
Expand All @@ -123,33 +118,20 @@ def _checkin(api: MockOPDS2WithODLApi, patron: Patron, pool: LicensePool) -> Non
}
)

api.queue_response(200, content=lsd)
api.queue_response(200)
api.queue_response(200, content=returned_lsd)
api.checkin(patron, "pin", pool)
self.mock_http.queue_response(200, content=lsd)
self.mock_http.queue_response(200, content="")
self.mock_http.queue_response(200, content=returned_lsd)
self.api.checkin(patron, "pin", pool)

def checkout(
self,
loan_url: str | None = None,
patron: Patron | None = None,
pool: LicensePool | None = None,
) -> tuple[LoanInfo, Any]:
) -> tuple[LoanInfo, Loan]:
patron = patron or self.patron
pool = pool or self.pool
loan_url = loan_url or self.db.fresh_url()
return self._checkout(
self.api, patron=patron, pool=pool, db=self.db, loan_url=loan_url
)

@staticmethod
def _checkout(
api: MockOPDS2WithODLApi,
patron: Patron,
pool: LicensePool,
db: DatabaseTransactionFixture,
loan_url: str,
) -> tuple[LoanInfo, Any]:
"""Create a function that, when evaluated, performs a checkout."""

lsd = json.dumps(
{
Expand All @@ -163,10 +145,10 @@ def _checkout(
],
}
)
api.queue_response(200, content=lsd)
loan = api.checkout(patron, "pin", pool, MagicMock())
self.mock_http.queue_response(200, content=lsd)
loan = self.api.checkout(patron, "pin", pool, MagicMock())
loan_db = (
db.session.query(Loan)
self.db.session.query(Loan)
.filter(Loan.license_pool == pool, Loan.patron == patron)
.one()
)
Expand Down
Loading

0 comments on commit f8440a8

Please sign in to comment.