diff --git a/src/palace/manager/api/odl/api.py b/src/palace/manager/api/odl/api.py index e8833c37e..4dda6f21f 100644 --- a/src/palace/manager/api/odl/api.py +++ b/src/palace/manager/api/odl/api.py @@ -63,7 +63,7 @@ from palace.manager.sqlalchemy.util import get_one from palace.manager.util import base64 from palace.manager.util.datetime_helpers import utc_now -from palace.manager.util.http import BadResponseException +from palace.manager.util.http import BadResponseException, RemoteIntegrationException class OPDS2WithODLApi( @@ -246,8 +246,10 @@ def get_license_status_document(self, loan: Loan) -> dict[str, Any]: hint_url=self.settings.passphrase_hint_url, ) - response = self._get(url) - if not (200 <= response.status_code < 300): + try: + response = self._get(url, allowed_response_codes=["2xx"]) + except BadResponseException as e: + response = e.response header_string = ", ".join( {f"{k}: {v}" for k, v in response.headers.items()} ) @@ -261,16 +263,17 @@ def get_license_status_document(self, loan: Loan) -> dict[str, Any]: f"status code {response.status_code}. Expected 2XX. Response headers: {header_string}. " f"Response content: {response_string}." ) - raise BadResponseException(url, "License Status Document request failed.") - + raise RemoteIntegrationException( + url, "License Status Document request failed." + ) from e try: status_doc = json.loads(response.content) except ValueError as e: - raise BadResponseException( + raise RemoteIntegrationException( url, "License Status Document was not valid JSON." - ) + ) from e if status_doc.get("status") not in self.STATUS_VALUES: - raise BadResponseException( + raise RemoteIntegrationException( url, "License Status Document had an unknown status value." ) return status_doc # type: ignore[no-any-return] @@ -958,7 +961,7 @@ def update_loan(self, loan: Loan, status_doc: dict[str, Any] | None = None) -> N # We already check that the status is valid in get_license_status_document, # but if the document came from a notification it hasn't been checked yet. if status not in self.STATUS_VALUES: - raise BadResponseException( + raise RemoteIntegrationException( str(loan.license.checkout_url), "The License Status Document had an unknown status value.", ) diff --git a/src/palace/manager/api/overdrive.py b/src/palace/manager/api/overdrive.py index 60ef9a135..34e61b5a1 100644 --- a/src/palace/manager/api/overdrive.py +++ b/src/palace/manager/api/overdrive.py @@ -562,10 +562,10 @@ def get( if status_code == 401: if exception_on_401: # This is our second try. Give up. - raise BadResponseException.from_response( + raise BadResponseException( url, "Something's wrong with the Overdrive OAuth Bearer Token!", - (status_code, headers, content), + response, ) else: # Refresh the token and try again. diff --git a/src/palace/manager/core/opds2_import.py b/src/palace/manager/core/opds2_import.py index 411cfb6d8..cf13d7427 100644 --- a/src/palace/manager/core/opds2_import.py +++ b/src/palace/manager/core/opds2_import.py @@ -10,6 +10,7 @@ import webpub_manifest_parser.opds2.ast as opds2_ast from flask_babel import lazy_gettext as _ +from requests import Response from sqlalchemy.orm import Session from uritemplate import URITemplate from webpub_manifest_parser.core import ManifestParserFactory, ManifestParserResult @@ -1157,18 +1158,16 @@ class OPDS2ImportMonitor(OPDSImportMonitor): PROTOCOL = OPDS2API.label() MEDIA_TYPE = OPDS2MediaTypesRegistry.OPDS_FEED.key, "application/json" - def _verify_media_type( - self, url: str, status_code: int, headers: Mapping[str, str], feed: bytes - ) -> None: + def _verify_media_type(self, url: str, resp: Response) -> None: # Make sure we got an OPDS feed, and not an error page that was # sent with a 200 status code. - media_type = headers.get("content-type") + media_type = resp.headers.get("content-type") if not media_type or not any(x in media_type for x in self.MEDIA_TYPE): message = "Expected {} OPDS 2.0 feed, got {}".format( self.MEDIA_TYPE, media_type ) - raise BadResponseException(url, message=message, status_code=status_code) + raise BadResponseException(url, message=message, response=resp) def _get_accept_header(self) -> str: return "{}, {};q=0.9, */*;q=0.1".format( diff --git a/src/palace/manager/core/opds_import.py b/src/palace/manager/core/opds_import.py index 00d30f221..3b93bd160 100644 --- a/src/palace/manager/core/opds_import.py +++ b/src/palace/manager/core/opds_import.py @@ -1855,17 +1855,15 @@ def identifier_needs_import( return True return False - def _verify_media_type( - self, url: str, status_code: int, headers: Mapping[str, str], feed: bytes - ) -> None: + def _verify_media_type(self, url: str, resp: Response) -> None: # Make sure we got an OPDS feed, and not an error page that was # sent with a 200 status code. - media_type = headers.get("content-type") + media_type = resp.headers.get("content-type") if not media_type or not any( x in media_type for x in (OPDSFeed.ATOM_LIKE_TYPES) ): message = "Expected Atom feed, got %s" % media_type - raise BadResponseException(url, message=message, status_code=status_code) + raise BadResponseException(url, message=message, response=resp) def follow_one_link( self, url: str, do_get: Callable[..., Response] | None = None @@ -1881,10 +1879,8 @@ def follow_one_link( get = do_get or self._get resp = get(url, {}) feed = resp.content - status_code = resp.status_code - headers = resp.headers - self._verify_media_type(url, status_code, headers, feed) + self._verify_media_type(url, resp) new_data = self.feed_contains_new_data(feed) diff --git a/src/palace/manager/util/http.py b/src/palace/manager/util/http.py index d4070bced..be31add28 100644 --- a/src/palace/manager/util/http.py +++ b/src/palace/manager/util/http.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Mapping, Sequence from json import JSONDecodeError -from typing import TYPE_CHECKING, Any +from typing import Any from urllib.parse import urlparse import requests @@ -28,9 +28,6 @@ ProblemDetailException, ) -if TYPE_CHECKING: - from palace.manager.sqlalchemy.model.resource import HttpResponseTuple - class RemoteIntegrationException(IntegrationException, BaseProblemDetailException): """An exception that happens when we try and fail to communicate @@ -99,55 +96,27 @@ def __init__( self, url_or_service: str, message: str, + response: Response, debug_message: str | None = None, - status_code: int | None = None, ): """Indicate that a remote integration has failed. `param url_or_service` The name of the service that failed (e.g. "Overdrive"), or the specific URL that had the problem. """ - super().__init__(url_or_service, message, debug_message) - # to be set to 500, etc. - self.status_code = status_code - - @classmethod - def from_response( - cls, url: str, message: str, response: HttpResponseTuple | Response - ) -> Self: - """Helper method to turn a `requests` Response object into - a BadResponseException. - """ - if isinstance(response, tuple): - # The response has been unrolled into a (status_code, - # headers, body) 3-tuple. - status_code, _, content_bytes = response - # The HTTP content response is a bytestring that we want to - # convert to unicode for the debug message. - if content_bytes: - content = content_bytes.decode("utf-8") - else: - content = "" - else: - status_code = response.status_code - content = response.text + if debug_message is None: + debug_message = ( + f"Status code: {response.status_code}\nContent: {response.text}" + ) - return cls( - url, - message, - status_code=status_code, - debug_message="Status code: %s\nContent: %s" - % ( - status_code, - content, - ), - ) + super().__init__(url_or_service, message, debug_message) + self.response = response @classmethod def bad_status_code(cls, url: str, response: Response) -> Self: """The response is bad because the status code is wrong.""" message = cls.BAD_STATUS_CODE_MESSAGE % response.status_code - return cls.from_response( + return cls( url, message, response, @@ -399,9 +368,9 @@ def _process_response( raise BadResponseException( url, error_message % code, - status_code=response.status_code, debug_message="Response content: %s" % cls._decode_response_content(response, url), + response=response, ) return response diff --git a/tests/fixtures/odl.py b/tests/fixtures/odl.py index aea2d9546..de9591c28 100644 --- a/tests/fixtures/odl.py +++ b/tests/fixtures/odl.py @@ -27,7 +27,7 @@ from palace.manager.sqlalchemy.model.work import Work from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import FilesFixture, OPDS2WithODLFilesFixture -from tests.mocks.mock import MockRequestsResponse +from tests.mocks.mock import MockHTTPClient, MockRequestsResponse from tests.mocks.odl import MockOPDS2WithODLApi @@ -44,7 +44,8 @@ def __init__( self.collection = self.create_collection(self.library) self.work = self.create_work(self.collection) self.license = self.setup_license() - self.api = MockOPDS2WithODLApi(self.db.session, self.collection) + self.mock_http = MockHTTPClient() + self.api = MockOPDS2WithODLApi(self.db.session, self.collection, self.mock_http) self.patron = db.patron() self.pool = self.license.license_pool @@ -100,12 +101,6 @@ def checkin( ) -> None: patron = patron or self.patron pool = pool or self.pool - self._checkin(self.api, patron=patron, pool=pool) - - @staticmethod - def _checkin(api: MockOPDS2WithODLApi, patron: Patron, pool: LicensePool) -> None: - """Create a function that, when evaluated, performs a checkin.""" - lsd = json.dumps( { "status": "ready", @@ -123,33 +118,20 @@ def _checkin(api: MockOPDS2WithODLApi, patron: Patron, pool: LicensePool) -> Non } ) - api.queue_response(200, content=lsd) - api.queue_response(200) - api.queue_response(200, content=returned_lsd) - api.checkin(patron, "pin", pool) + self.mock_http.queue_response(200, content=lsd) + self.mock_http.queue_response(200, content="") + self.mock_http.queue_response(200, content=returned_lsd) + self.api.checkin(patron, "pin", pool) def checkout( self, loan_url: str | None = None, patron: Patron | None = None, pool: LicensePool | None = None, - ) -> tuple[LoanInfo, Any]: + ) -> tuple[LoanInfo, Loan]: patron = patron or self.patron pool = pool or self.pool loan_url = loan_url or self.db.fresh_url() - return self._checkout( - self.api, patron=patron, pool=pool, db=self.db, loan_url=loan_url - ) - - @staticmethod - def _checkout( - api: MockOPDS2WithODLApi, - patron: Patron, - pool: LicensePool, - db: DatabaseTransactionFixture, - loan_url: str, - ) -> tuple[LoanInfo, Any]: - """Create a function that, when evaluated, performs a checkout.""" lsd = json.dumps( { @@ -163,10 +145,10 @@ def _checkout( ], } ) - api.queue_response(200, content=lsd) - loan = api.checkout(patron, "pin", pool, MagicMock()) + self.mock_http.queue_response(200, content=lsd) + loan = self.api.checkout(patron, "pin", pool, MagicMock()) loan_db = ( - db.session.query(Loan) + self.db.session.query(Loan) .filter(Loan.license_pool == pool, Loan.patron == patron) .one() ) diff --git a/tests/manager/api/odl/test_api.py b/tests/manager/api/odl/test_api.py index 63b686514..a54875221 100644 --- a/tests/manager/api/odl/test_api.py +++ b/tests/manager/api/odl/test_api.py @@ -44,7 +44,7 @@ from palace.manager.sqlalchemy.model.work import Work from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import datetime_utc, utc_now -from palace.manager.util.http import BadResponseException +from palace.manager.util.http import RemoteIntegrationException from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.odl import OPDS2WithODLApiFixture @@ -195,11 +195,11 @@ def test_get_license_status_document_success( loan, _ = opds2_with_odl_api_fixture.license.loan_to( opds2_with_odl_api_fixture.patron ) - opds2_with_odl_api_fixture.api.queue_response( + opds2_with_odl_api_fixture.mock_http.queue_response( 201, content=json.dumps(dict(status="ready")) ) opds2_with_odl_api_fixture.api.get_license_status_document(loan) - requested_url = opds2_with_odl_api_fixture.api.requests[0][0] + requested_url = opds2_with_odl_api_fixture.mock_http.requests.pop() parsed = urlparse(requested_url) assert "https" == parsed.scheme @@ -250,11 +250,11 @@ def test_get_license_status_document_success( ) loan.external_identifier = opds2_with_odl_api_fixture.db.fresh_str() - opds2_with_odl_api_fixture.api.queue_response( + opds2_with_odl_api_fixture.mock_http.queue_response( 200, content=json.dumps(dict(status="active")) ) opds2_with_odl_api_fixture.api.get_license_status_document(loan) - requested_url = opds2_with_odl_api_fixture.api.requests[1][0] + requested_url = opds2_with_odl_api_fixture.mock_http.requests.pop() assert loan.external_identifier == requested_url def test_get_license_status_document_errors( @@ -266,25 +266,27 @@ def test_get_license_status_document_errors( opds2_with_odl_api_fixture.patron ) - opds2_with_odl_api_fixture.api.queue_response(200, content="not json") + opds2_with_odl_api_fixture.mock_http.queue_response(200, content="not json") pytest.raises( - BadResponseException, + RemoteIntegrationException, opds2_with_odl_api_fixture.api.get_license_status_document, loan, ) - opds2_with_odl_api_fixture.api.queue_response( + opds2_with_odl_api_fixture.mock_http.queue_response( 200, content=json.dumps(dict(status="unknown")) ) pytest.raises( - BadResponseException, + RemoteIntegrationException, opds2_with_odl_api_fixture.api.get_license_status_document, loan, ) - opds2_with_odl_api_fixture.api.queue_response(403, content="just junk " * 100) + opds2_with_odl_api_fixture.mock_http.queue_response( + 403, content="just junk " * 100 + ) pytest.raises( - BadResponseException, + RemoteIntegrationException, opds2_with_odl_api_fixture.api.get_license_status_document, loan, ) @@ -307,10 +309,10 @@ def test_checkin_success( # The patron returns the book successfully. opds2_with_odl_api_fixture.checkin() - assert 3 == len(opds2_with_odl_api_fixture.api.requests) - assert "http://loan" in opds2_with_odl_api_fixture.api.requests[0][0] - assert "http://return" == opds2_with_odl_api_fixture.api.requests[1][0] - assert "http://loan" in opds2_with_odl_api_fixture.api.requests[2][0] + assert 3 == len(opds2_with_odl_api_fixture.mock_http.requests) + assert "http://loan" in opds2_with_odl_api_fixture.mock_http.requests[0] + assert "http://return" == opds2_with_odl_api_fixture.mock_http.requests[1] + assert "http://loan" in opds2_with_odl_api_fixture.mock_http.requests[2] # The pool's availability has increased, and the local loan has # been deleted. @@ -342,10 +344,10 @@ def test_checkin_success_with_holds_queue( # The first patron returns the book successfully. opds2_with_odl_api_fixture.checkin() - assert 3 == len(opds2_with_odl_api_fixture.api.requests) - assert "http://loan" in opds2_with_odl_api_fixture.api.requests[0][0] - assert "http://return" == opds2_with_odl_api_fixture.api.requests[1][0] - assert "http://loan" in opds2_with_odl_api_fixture.api.requests[2][0] + assert 3 == len(opds2_with_odl_api_fixture.mock_http.requests) + assert "http://loan" in opds2_with_odl_api_fixture.mock_http.requests[0] + assert "http://return" == opds2_with_odl_api_fixture.mock_http.requests[1] + assert "http://loan" in opds2_with_odl_api_fixture.mock_http.requests[2] # Now the license is reserved for the next patron. assert 0 == opds2_with_odl_api_fixture.pool.licenses_available @@ -373,12 +375,12 @@ def test_checkin_already_fulfilled( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) # Checking in the book silently does nothing. opds2_with_odl_api_fixture.api.checkin( opds2_with_odl_api_fixture.patron, "pinn", opds2_with_odl_api_fixture.pool ) - assert 1 == len(opds2_with_odl_api_fixture.api.requests) + assert 1 == len(opds2_with_odl_api_fixture.mock_http.requests) assert 6 == opds2_with_odl_api_fixture.pool.licenses_available assert 1 == db.session.query(Loan).count() @@ -409,7 +411,7 @@ def test_checkin_not_checked_out( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) pytest.raises( NotCheckedOut, opds2_with_odl_api_fixture.api.checkin, @@ -436,7 +438,7 @@ def test_checkin_cannot_return( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) # Checking in silently does nothing. opds2_with_odl_api_fixture.api.checkin( opds2_with_odl_api_fixture.patron, "pin", opds2_with_odl_api_fixture.pool @@ -456,9 +458,9 @@ def test_checkin_cannot_return( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) - opds2_with_odl_api_fixture.api.queue_response(200, content="Deleted") - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content="Deleted") + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) opds2_with_odl_api_fixture.api.checkin( opds2_with_odl_api_fixture.patron, "pin", opds2_with_odl_api_fixture.pool ) @@ -772,7 +774,7 @@ def test_checkout_cannot_loan( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) pytest.raises( CannotLoan, opds2_with_odl_api_fixture.api.checkout, @@ -792,7 +794,7 @@ def test_checkout_cannot_loan( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) pytest.raises( CannotLoan, opds2_with_odl_api_fixture.api.checkout, @@ -876,7 +878,7 @@ def test_fulfill_success( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) fulfillment = opds2_with_odl_api_fixture.api.fulfill( opds2_with_odl_api_fixture.patron, "pin", @@ -1036,7 +1038,7 @@ def test_fulfill_cannot_fulfill( } ) - opds2_with_odl_api_fixture.api.queue_response(200, content=lsd) + opds2_with_odl_api_fixture.mock_http.queue_response(200, content=lsd) pytest.raises( CannotFulfill, opds2_with_odl_api_fixture.api.fulfill, @@ -1811,6 +1813,18 @@ def test_update_loan_still_active( assert 6 == opds2_with_odl_api_fixture.pool.licenses_available assert 1 == db.session.query(Loan).count() + def test_update_loan_bad_status( + self, + db: DatabaseTransactionFixture, + opds2_with_odl_api_fixture: OPDS2WithODLApiFixture, + ) -> None: + status_doc = { + "status": "foo", + } + + with pytest.raises(RemoteIntegrationException, match="unknown status value"): + opds2_with_odl_api_fixture.api.update_loan(MagicMock(), status_doc) + def test_update_loan_removes_loan( self, db: DatabaseTransactionFixture, diff --git a/tests/manager/core/test_opds2_import.py b/tests/manager/core/test_opds2_import.py index bd7bb43e9..1dee74486 100644 --- a/tests/manager/core/test_opds2_import.py +++ b/tests/manager/core/test_opds2_import.py @@ -1,6 +1,7 @@ import datetime import json from collections.abc import Generator +from contextlib import nullcontext from unittest.mock import MagicMock, patch import pytest @@ -11,7 +12,12 @@ from palace.manager.api.circulation import CirculationAPI, FulfillmentInfo from palace.manager.api.circulation_exceptions import CannotFulfill -from palace.manager.core.opds2_import import OPDS2API, OPDS2Importer, RWPMManifestParser +from palace.manager.core.opds2_import import ( + OPDS2API, + OPDS2Importer, + OPDS2ImportMonitor, + RWPMManifestParser, +) from palace.manager.sqlalchemy.constants import ( EditionConstants, IdentifierType, @@ -29,8 +35,10 @@ from palace.manager.sqlalchemy.model.patron import Loan from palace.manager.sqlalchemy.model.work import Work from palace.manager.util.datetime_helpers import utc_now +from palace.manager.util.http import BadResponseException from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import OPDS2FilesFixture +from tests.mocks.mock import MockRequestsResponse class OPDS2Test: @@ -893,3 +901,43 @@ def test_get_authentication_token_bad_response( OPDS2API.get_authentication_token( opds2_api_fixture.patron, opds2_api_fixture.data_source, "" ) + + +class TestOPDS2ImportMonitor: + @pytest.mark.parametrize( + "content_type,exception", + [ + ("application/json", False), + ("application/opds+json", False), + ("application/xml", True), + ("foo/xyz", True), + ], + ) + def test__verify_media_type( + self, db: DatabaseTransactionFixture, content_type: str, exception: bool + ) -> None: + collection = db.collection( + protocol=OPDS2API.label(), + data_source_name="test", + external_account_id="http://test.com", + ) + monitor = OPDS2ImportMonitor( + db.session, + collection, + OPDS2Importer, + parser=RWPMManifestParser(OPDS2FeedParserFactory()), + ) + + ctx_manager = ( + nullcontext() + if not exception + else pytest.raises( + BadResponseException, match="Bad response from http://test.com" + ) + ) + + mock_response = MockRequestsResponse( + status_code=200, headers={"Content-Type": content_type} + ) + with ctx_manager: + monitor._verify_media_type("http://test.com", mock_response) diff --git a/tests/manager/util/test_http.py b/tests/manager/util/test_http.py index 3fdb3aa7e..e5309268f 100644 --- a/tests/manager/util/test_http.py +++ b/tests/manager/util/test_http.py @@ -394,14 +394,14 @@ def test_with_debug_message(self): class TestBadResponseException: - def test_from_response(self): + def test__init__(self): response = MockRequestsResponse(102, content="nonsense") - exc = BadResponseException.from_response( + exc = BadResponseException( "http://url/", "Terrible response, just terrible", response ) - # the status code gets set on the exception - assert exc.status_code == 102 + # the response gets set on the exception + assert exc.response is response # Turn the exception into a problem detail document, and it's full # of useful information. @@ -418,7 +418,7 @@ def test_from_response(self): ) assert problem_detail.status_code == 502 - def test_bad_status_code(object): + def test_bad_status_code(self): response = MockRequestsResponse(500, content="Internal Server Error!") exc = BadResponseException.bad_status_code("http://url/", response) doc = exc.problem_detail @@ -434,11 +434,12 @@ def test_bad_status_code(object): ) def test_problem_detail(self): + response = MockRequestsResponse(401, content="You are not authorized!") exception = BadResponseException( "http://url/", "What even is this", debug_message="some debug info", - status_code=401, + response=response, ) document = exception.problem_detail assert 502 == document.status_code @@ -451,7 +452,7 @@ def test_problem_detail(self): "Bad response from http://url/: What even is this\n\nsome debug info" == document.debug_message ) - assert exception.status_code == 401 + assert exception.response is response class TestRequestTimedOut: diff --git a/tests/mocks/mock.py b/tests/mocks/mock.py index 239cdb6ec..f0fa9af57 100644 --- a/tests/mocks/mock.py +++ b/tests/mocks/mock.py @@ -1,6 +1,9 @@ import json import logging -from typing import Any +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any, NamedTuple +from unittest.mock import patch from requests import Request, Response @@ -13,6 +16,7 @@ from palace.manager.core.opds_import import OPDSAPI from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.resource import HttpResponseTuple +from palace.manager.util.http import HTTP def _normalize_level(level): @@ -216,10 +220,18 @@ def process_batch(self, batch): return [] +class Args(NamedTuple): + """A simple container for positional and keyword arguments.""" + + args: tuple[Any, ...] + kwargs: dict[str, Any] + + class MockHTTPClient: def __init__(self) -> None: self.responses: list[Response] = [] self.requests: list[str] = [] + self.requests_args: list[Args] = [] def queue_response( self, @@ -235,9 +247,18 @@ def queue_response( self.responses.append(MockRequestsResponse(response_code, headers, content)) + def _get(self, *args: Any, **kwargs: Any) -> Response: + return self.responses.pop(0) + def do_get(self, url: str, *args: Any, **kwargs: Any) -> Response: self.requests.append(url) - return self.responses.pop(0) + self.requests_args.append(Args(args, kwargs)) + return HTTP._request_with_timeout(url, self._get, *args, **kwargs) + + @contextmanager + def patch(self) -> Generator[None, None, None]: + with patch.object(HTTP, "get_with_timeout", self.do_get): + yield class MockRepresentationHTTPClient: diff --git a/tests/mocks/odl.py b/tests/mocks/odl.py index 2f21a21e2..bc7e7b346 100644 --- a/tests/mocks/odl.py +++ b/tests/mocks/odl.py @@ -10,8 +10,7 @@ from palace.manager.api.odl.settings import OPDS2AuthType from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.util.datetime_helpers import utc_now -from palace.manager.util.http import HTTP -from tests.mocks.mock import MockRequestsResponse +from tests.mocks.mock import MockHTTPClient class MockOPDS2WithODLApi(OPDS2WithODLApi): @@ -19,12 +18,11 @@ def __init__( self, _db: Session, collection: Collection, + mock_http_client: MockHTTPClient, ) -> None: super().__init__(_db, collection) - self.responses: list[MockRequestsResponse] = [] - self.requests: list[ - tuple[str, Mapping[str, str] | None, Mapping[str, Any]] - ] = [] + + self.mock_http_client = mock_http_client self.mock_auth_type = self.settings.auth_type self.refresh_token_calls = 0 self.refresh_token_timedelta = timedelta(minutes=30) @@ -39,16 +37,6 @@ def _refresh_token(self) -> None: token="new_token", expires=utc_now() + self.refresh_token_timedelta ) - def queue_response( - self, - status_code: int, - headers: dict[str, str] | None = None, - content: str | None = None, - ): - if headers is None: - headers = {} - self.responses.insert(0, MockRequestsResponse(status_code, headers, content)) - def _url_for(self, *args: Any, **kwargs: Any) -> str: del kwargs["_external"] return "http://{}?{}".format( @@ -59,6 +47,4 @@ def _url_for(self, *args: Any, **kwargs: Any) -> str: def _get( self, url: str, headers: Mapping[str, str] | None = None, **kwargs: Any ) -> Response: - self.requests.append((url, headers, kwargs)) - response = self.responses.pop() - return HTTP._process_response(url, response) + return self.mock_http_client.do_get(url, headers=headers, **kwargs)