From d7ab9fbe7a58ea13225644ed0f1ed9647c5fa2c0 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 31 Jul 2024 10:14:39 -0300 Subject: [PATCH] Refactor OPDS classes to use HTTP.get_with_timeout (PP-1485) (#1956) * Refactor OPDS classes to use HTTP.get_with_timeout functions instead of resource function. --- src/palace/manager/api/odl/importer.py | 21 ++-- .../manager/api/opds_for_distributors.py | 4 +- src/palace/manager/core/opds2_import.py | 11 +- src/palace/manager/core/opds_import.py | 45 ++----- tests/fixtures/odl.py | 7 +- tests/manager/api/controller/test_loan.py | 6 +- tests/manager/api/metadata/test_novelist.py | 4 +- tests/manager/api/odl/test_importer.py | 46 ++++---- tests/manager/api/test_enki.py | 2 +- .../manager/api/test_opds_for_distributors.py | 5 +- tests/manager/api/test_overdrive.py | 29 +++-- tests/manager/core/test_opds_import.py | 27 +---- .../manager/sqlalchemy/model/test_resource.py | 16 +-- tests/mocks/mock.py | 111 +++++++++++------- 14 files changed, 160 insertions(+), 174 deletions(-) diff --git a/src/palace/manager/api/odl/importer.py b/src/palace/manager/api/odl/importer.py index 1a538bfccc..e915ad0aaa 100644 --- a/src/palace/manager/api/odl/importer.py +++ b/src/palace/manager/api/odl/importer.py @@ -1,11 +1,11 @@ from __future__ import annotations import datetime -import json from collections.abc import Callable from typing import TYPE_CHECKING, Any import dateutil +from requests import Response from sqlalchemy.orm import Session from webpub_manifest_parser.odl import ODLFeedParserFactory from webpub_manifest_parser.opds2.registry import OPDS2LinkRelationsRegistry @@ -27,9 +27,9 @@ LicenseStatus, RightsStatus, ) -from palace.manager.sqlalchemy.model.resource import HttpResponseTuple from palace.manager.util import first_or_default from palace.manager.util.datetime_helpers import to_utc +from palace.manager.util.http import HTTP if TYPE_CHECKING: from webpub_manifest_parser.core.ast import Metadata @@ -63,7 +63,7 @@ def __init__( collection: Collection, parser: RWPMManifestParser | None = None, data_source_name: str | None = None, - http_get: Callable[..., HttpResponseTuple] | None = None, + http_get: Callable[..., Response] | None = None, ): """Initialize a new instance of OPDS2WithODLImporter class. @@ -90,9 +90,10 @@ def __init__( collection, parser if parser else RWPMManifestParser(ODLFeedParserFactory()), data_source_name, - http_get, ) + self.http_get = http_get or HTTP.get_with_timeout + def _extract_publication_metadata( self, feed: OPDS2Feed, @@ -229,16 +230,16 @@ def _extract_publication_metadata( @classmethod def fetch_license_info( - cls, document_link: str, do_get: Callable[..., HttpResponseTuple] + cls, document_link: str, do_get: Callable[..., Response] ) -> dict[str, Any] | None: - status_code, _, response = do_get(document_link, headers={}) - if status_code in (200, 201): - license_info_document = json.loads(response) + resp = do_get(document_link, headers={}) + if resp.status_code in (200, 201): + license_info_document = resp.json() return license_info_document # type: ignore[no-any-return] else: cls.logger().warning( f"License Info Document is not available. " - f"Status link {document_link} failed with {status_code} code." + f"Status link {document_link} failed with {resp.status_code} code." ) return None @@ -337,7 +338,7 @@ def get_license_data( feed_license_identifier: str | None, feed_license_expires: datetime.datetime | None, feed_concurrency: int | None, - do_get: Callable[..., HttpResponseTuple], + do_get: Callable[..., Response], ) -> LicenseData | None: license_info_document = cls.fetch_license_info(license_info_link, do_get) diff --git a/src/palace/manager/api/opds_for_distributors.py b/src/palace/manager/api/opds_for_distributors.py index 77a895be29..b0ef9ca291 100644 --- a/src/palace/manager/api/opds_for_distributors.py +++ b/src/palace/manager/api/opds_for_distributors.py @@ -38,7 +38,7 @@ RightsStatus, ) from palace.manager.sqlalchemy.model.patron import Loan, Patron -from palace.manager.sqlalchemy.model.resource import HttpResponseTuple, Hyperlink +from palace.manager.sqlalchemy.model.resource import Hyperlink from palace.manager.sqlalchemy.util import get_one from palace.manager.util import base64 from palace.manager.util.datetime_helpers import utc_now @@ -446,7 +446,7 @@ def __init__( self.api = OPDSForDistributorsAPI(_db, collection) - def _get(self, url: str, headers: Mapping[str, str]) -> HttpResponseTuple: + def _get(self, url: str, headers: Mapping[str, str]) -> Response: """Make a normal HTTP request for an OPDS feed, but add in an auth header with the credentials for the collection. """ diff --git a/src/palace/manager/core/opds2_import.py b/src/palace/manager/core/opds2_import.py index 8f03bf6c94..3448be487d 100644 --- a/src/palace/manager/core/opds2_import.py +++ b/src/palace/manager/core/opds2_import.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Iterable, Mapping from datetime import datetime from io import BytesIO, StringIO from typing import TYPE_CHECKING, Any @@ -63,11 +63,7 @@ RightsStatus, ) from palace.manager.sqlalchemy.model.patron import Patron -from palace.manager.sqlalchemy.model.resource import ( - HttpResponseTuple, - Hyperlink, - Representation, -) +from palace.manager.sqlalchemy.model.resource import Hyperlink, Representation from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.http import HTTP, BadResponseException from palace.manager.util.opds_writer import OPDSFeed @@ -282,7 +278,6 @@ def __init__( collection: Collection, parser: RWPMManifestParser, data_source_name: str | None = None, - http_get: Callable[..., HttpResponseTuple] | None = None, ): """Initialize a new instance of OPDS2Importer class. @@ -298,7 +293,7 @@ def __init__( NOTE: If `collection` is provided, its .data_source will take precedence over any value provided here. This is only for use when you are importing OPDS metadata without any particular Collection in mind. """ - super().__init__(db, collection, data_source_name, http_get) + super().__init__(db, collection, data_source_name) self._parser = parser self.ignored_identifier_types = self.settings.ignored_identifier_types diff --git a/src/palace/manager/core/opds_import.py b/src/palace/manager/core/opds_import.py index e4d1955f23..00d30f2215 100644 --- a/src/palace/manager/core/opds_import.py +++ b/src/palace/manager/core/opds_import.py @@ -19,6 +19,7 @@ from flask_babel import lazy_gettext as _ from lxml import etree from pydantic import AnyHttpUrl +from requests import Response from sqlalchemy.orm.session import Session from palace.manager.api.circulation import ( @@ -75,11 +76,7 @@ ) from palace.manager.sqlalchemy.model.measurement import Measurement from palace.manager.sqlalchemy.model.patron import Patron -from palace.manager.sqlalchemy.model.resource import ( - HttpResponseTuple, - Hyperlink, - Representation, -) +from palace.manager.sqlalchemy.model.resource import Hyperlink from palace.manager.sqlalchemy.util import get_one from palace.manager.util import base64 from palace.manager.util.datetime_helpers import datetime_utc, to_utc, utc_now @@ -382,7 +379,6 @@ def __init__( _db: Session, collection: Collection, data_source_name: str | None, - http_get: Callable[..., HttpResponseTuple] | None = None, ): self._db = _db if collection.id is None: @@ -401,11 +397,6 @@ def __init__( "Cannot perform an OPDS import on a Collection that has no associated DataSource!" ) self.data_source_name = data_source_name - - # In general, we are cautious when mirroring resources so that - # we don't, e.g. accidentally get our IP banned from - # gutenberg.org. - self.http_get = http_get or Representation.cautious_http_get self.settings = integration_settings_load( self.settings_class(), collection.integration_configuration ) @@ -690,7 +681,6 @@ def __init__( _db: Session, collection: Collection, data_source_name: str | None = None, - http_get: Callable[..., HttpResponseTuple] | None = None, ): """:param collection: LicensePools created by this OPDS import will be associated with the given Collection. If this is None, @@ -703,19 +693,11 @@ def __init__( .data_source will take precedence over any value provided here. This is only for use when you are importing OPDS metadata without any particular Collection in mind. - - :param http_get: Use this method to make an HTTP GET request. This - can be replaced with a stub method for testing purposes. """ super().__init__(_db, collection, data_source_name) self.primary_identifier_source = self.settings.primary_identifier_source - # In general, we are cautious when mirroring resources so that - # we don't, e.g. accidentally get our IP banned from - # gutenberg.org. - self.http_get = http_get or Representation.cautious_http_get - def extract_next_links(self, feed: str | bytes | FeedParserDict) -> list[str]: if isinstance(feed, (bytes, str)): parsed = feedparser.parse(feed) @@ -754,7 +736,7 @@ def extract_feed_data( ) # gets: medium, measurements, links, contributors, etc. xml_data_meta, xml_failures = self.extract_metadata_from_elementtree( - feed, data_source=data_source, feed_url=feed_url, do_get=self.http_get + feed, data_source=data_source, feed_url=feed_url ) # translate the id in failures to identifier.urn @@ -979,7 +961,6 @@ def extract_metadata_from_elementtree( feed: bytes | str, data_source: DataSource, feed_url: str | None = None, - do_get: Callable[..., HttpResponseTuple] | None = None, ) -> tuple[dict[str, Any], dict[str, CoverageFailure]]: """Parse the OPDS as XML and extract all author and subject information, as well as ratings and medium. @@ -1027,7 +1008,7 @@ def extract_metadata_from_elementtree( # Then turn Atom tags into Metadata objects. for entry in parser._xpath(root, "/atom:feed/atom:entry"): identifier, detail, failure_entry = cls.detail_for_elementtree_entry( - parser, entry, data_source, feed_url, do_get=do_get + parser, entry, data_source, feed_url ) if identifier: if failure_entry: @@ -1309,7 +1290,6 @@ def detail_for_elementtree_entry( entry_tag: Element, data_source: DataSource, feed_url: str | None = None, - do_get: Callable[..., HttpResponseTuple] | None = None, ) -> tuple[str | None, dict[str, Any] | None, CoverageFailure | None]: """Turn an tag into a dictionary of metadata that can be used as keyword arguments to the Metadata contructor. @@ -1324,9 +1304,7 @@ def detail_for_elementtree_entry( identifier = identifier.text try: - data = cls._detail_for_elementtree_entry( - parser, entry_tag, feed_url, do_get=do_get - ) + data = cls._detail_for_elementtree_entry(parser, entry_tag, feed_url) return identifier, data, None except Exception as e: @@ -1343,7 +1321,6 @@ def _detail_for_elementtree_entry( parser: OPDSXMLParser, entry_tag: Element, feed_url: str | None = None, - do_get: Callable[..., HttpResponseTuple] | None = None, ) -> dict[str, Any]: """Helper method that extracts metadata and circulation data from an elementtree entry. This method can be overridden in tests to check that callers handle things @@ -1734,7 +1711,7 @@ def __init__( self._feed_base_url = f"{parsed_url.scheme}://{parsed_url.hostname}{(':' + str(parsed_url.port)) if parsed_url.port else ''}/" super().__init__(_db, collection) - def _get(self, url: str, headers: Mapping[str, str]) -> HttpResponseTuple: + def _get(self, url: str, headers: Mapping[str, str]) -> Response: """Make the sort of HTTP request that's normal for an OPDS feed. Long timeout, raise error on anything but 2xx or 3xx. @@ -1748,8 +1725,7 @@ def _get(self, url: str, headers: Mapping[str, str]) -> HttpResponseTuple: ) if not url.startswith("http"): url = urljoin(self._feed_base_url, url) - response = HTTP.get_with_timeout(url, headers=headers, **kwargs) - return response.status_code, response.headers, response.content + return HTTP.get_with_timeout(url, headers=headers, **kwargs) def _get_accept_header(self) -> str: return ",".join( @@ -1892,7 +1868,7 @@ def _verify_media_type( raise BadResponseException(url, message=message, status_code=status_code) def follow_one_link( - self, url: str, do_get: Callable[..., HttpResponseTuple] | None = None + self, url: str, do_get: Callable[..., Response] | None = None ) -> tuple[list[str], bytes | None]: """Download a representation of a URL and extract the useful information. @@ -1903,7 +1879,10 @@ def follow_one_link( """ self.log.info("Following next link: %s", url) get = do_get or self._get - status_code, headers, feed = get(url, {}) + resp = get(url, {}) + feed = resp.content + status_code = resp.status_code + headers = resp.headers self._verify_media_type(url, status_code, headers, feed) diff --git a/tests/fixtures/odl.py b/tests/fixtures/odl.py index e089b298f1..759ce55ce7 100644 --- a/tests/fixtures/odl.py +++ b/tests/fixtures/odl.py @@ -8,6 +8,7 @@ import pytest from jinja2 import Template +from requests import Response from palace.manager.api.circulation import LoanInfo from palace.manager.api.odl.api import OPDS2WithODLApi @@ -23,10 +24,10 @@ LicensePoolDeliveryMechanism, ) from palace.manager.sqlalchemy.model.patron import Loan, Patron -from palace.manager.sqlalchemy.model.resource import HttpResponseTuple from palace.manager.sqlalchemy.model.work import Work from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import FilesFixture, OPDS2WithODLFilesFixture +from tests.mocks.mock import MockRequestsResponse from tests.mocks.odl import MockOPDS2WithODLApi @@ -264,8 +265,8 @@ def __init__( http_get=self.get_response, ) - def get_response(self, *args: Any, **kwargs: Any) -> HttpResponseTuple: - return 200, {}, self.responses.pop(0) + def get_response(self, *args: Any, **kwargs: Any) -> Response: + return MockRequestsResponse(200, content=self.responses.pop(0)) def queue_response(self, item: LicenseInfoHelper | str | bytes) -> None: if isinstance(item, LicenseInfoHelper): diff --git a/tests/manager/api/controller/test_loan.py b/tests/manager/api/controller/test_loan.py index 57db5b8d14..cf7462da2b 100644 --- a/tests/manager/api/controller/test_loan.py +++ b/tests/manager/api/controller/test_loan.py @@ -69,7 +69,7 @@ from tests.fixtures.redis import RedisFixture from tests.fixtures.services import ServicesFixture from tests.mocks.circulation import MockPatronActivityCirculationAPI -from tests.mocks.mock import DummyHTTPClient +from tests.mocks.mock import MockRepresentationHTTPClient class LoanFixture(CirculationControllerFixture): @@ -423,7 +423,7 @@ def test_borrow_success( # external request to obtain the book. loan_fixture.pool.open_access = False - http = DummyHTTPClient() + http = MockRepresentationHTTPClient() fulfillment = FulfillmentInfo( loan_fixture.pool.collection, @@ -602,7 +602,7 @@ def test_borrow_and_fulfill_with_streaming_delivery_mechanism( assert None == loan.fulfillment # We can still use the other mechanism too. - http = DummyHTTPClient() + http = MockRepresentationHTTPClient() http.queue_response(200, content="I am an ACSM file") loan_fixture.manager.d_circulation.queue_fulfill( diff --git a/tests/manager/api/metadata/test_novelist.py b/tests/manager/api/metadata/test_novelist.py index 09b9a5a0cb..8d2e3de57a 100644 --- a/tests/manager/api/metadata/test_novelist.py +++ b/tests/manager/api/metadata/test_novelist.py @@ -15,7 +15,7 @@ from palace.manager.util.http import HTTP from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import FilesFixture -from tests.mocks.mock import DummyHTTPClient, MockRequestsResponse +from tests.mocks.mock import MockRepresentationHTTPClient, MockRequestsResponse class NoveListFilesFixture(FilesFixture): @@ -216,7 +216,7 @@ def test_get_series_information(self, novelist_fixture: NoveListFixture): def test_lookup(self, novelist_fixture: NoveListFixture): # Test the lookup() method. - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(200, "text/html", content="yay") novelist = novelist_fixture.novelist diff --git a/tests/manager/api/odl/test_importer.py b/tests/manager/api/odl/test_importer.py index f17b3ee903..07803d48ef 100644 --- a/tests/manager/api/odl/test_importer.py +++ b/tests/manager/api/odl/test_importer.py @@ -10,6 +10,7 @@ import dateutil import pytest from freezegun import freeze_time +from requests import Response from webpub_manifest_parser.odl.ast import ODLPublication from webpub_manifest_parser.odl.semantic import ( ODL_PUBLICATION_MUST_CONTAIN_EITHER_LICENSES_OR_OA_ACQUISITION_LINK_ERROR, @@ -31,7 +32,7 @@ LicensePool, LicenseStatus, ) -from palace.manager.sqlalchemy.model.resource import HttpResponseTuple, Hyperlink +from palace.manager.sqlalchemy.model.resource import Hyperlink from palace.manager.sqlalchemy.model.work import Work from palace.manager.util import datetime_helpers from palace.manager.util.datetime_helpers import utc_now @@ -41,6 +42,7 @@ LicenseInfoHelper, OPDS2WithODLImporterFixture, ) +from tests.mocks.mock import MockHTTPClient class TestOPDS2WithODLImporter: @@ -915,50 +917,50 @@ def license_info_dict() -> dict[str, Any]: def test_fetch_license_info(self): """Ensure that OPDS2WithODLImporter correctly retrieves license data from an OPDS2 feed.""" - responses: list[HttpResponseTuple] = [] - requests: list[str] = [] - - def get(url: str, *args: Any, **kwargs: Any) -> HttpResponseTuple: - requests.append(url) - return responses.pop(0) + http = MockHTTPClient() # Bad status code - responses.append((400, {}, b"Bad Request")) + http.queue_response(400, content=b"Bad Request") assert ( - OPDS2WithODLImporter.fetch_license_info("http://example.org/feed", get) + OPDS2WithODLImporter.fetch_license_info( + "http://example.org/feed", http.do_get + ) is None ) - assert len(requests) == 1 - assert requests.pop() == "http://example.org/feed" + assert len(http.requests) == 1 + assert http.requests.pop() == "http://example.org/feed" # 200 status - json decodes body and returns it - responses.append((200, {}, json.dumps(["a", "b"]).encode("utf-8"))) + http.queue_response(200, content=json.dumps(["a", "b"])) assert OPDS2WithODLImporter.fetch_license_info( - "http://example.org/feed", get + "http://example.org/feed", http.do_get ) == [ "a", "b", ] - assert len(requests) == 1 - assert requests.pop() == "http://example.org/feed" + assert len(http.requests) == 1 + assert http.requests.pop() == "http://example.org/feed" # 201 status - json decodes body and returns it - responses.append((201, {}, json.dumps({"test": "123"}).encode("utf-8"))) + http.queue_response(201, content=json.dumps({"test": "123"})) assert OPDS2WithODLImporter.fetch_license_info( - "http://example.org/feed", get + "http://example.org/feed", http.do_get ) == {"test": "123"} - assert len(requests) == 1 - assert requests.pop() == "http://example.org/feed" + assert len(http.requests) == 1 + assert http.requests.pop() == "http://example.org/feed" def test_get_license_data(self, monkeypatch: pytest.MonkeyPatch): expires = utc_now() + datetime.timedelta(days=1) responses: list[tuple[int, str]] = [] - def get(url: str, *args: Any, **kwargs: Any) -> HttpResponseTuple: - resp = responses.pop(0) - return resp[0], {}, resp[1].encode("utf-8") + def get(url: str, *args: Any, **kwargs: Any) -> Response: + status_code, body = responses.pop(0) + resp = Response() + resp.status_code = status_code + resp._content = body.encode("utf-8") + return resp def get_license_data() -> LicenseData | None: return OPDS2WithODLImporter.get_license_data( diff --git a/tests/manager/api/test_enki.py b/tests/manager/api/test_enki.py index 2840a64174..9a59b82447 100644 --- a/tests/manager/api/test_enki.py +++ b/tests/manager/api/test_enki.py @@ -594,7 +594,7 @@ def test_patron_activity(self, enki_test_fixture: EnkiTestFixure): def test_patron_activity_failure(self, enki_test_fixture: EnkiTestFixure): db = enki_test_fixture.db patron = db.patron() - enki_test_fixture.api.queue_response(404, "No such patron") + enki_test_fixture.api.queue_response(404, content="No such patron") collect = lambda: list(enki_test_fixture.api.patron_activity(patron, "pin")) pytest.raises(PatronAuthorizationFailedException, collect) diff --git a/tests/manager/api/test_opds_for_distributors.py b/tests/manager/api/test_opds_for_distributors.py index a913963769..7364401b20 100644 --- a/tests/manager/api/test_opds_for_distributors.py +++ b/tests/manager/api/test_opds_for_distributors.py @@ -35,6 +35,7 @@ from palace.manager.util.opds_writer import OPDSFeed from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import FilesFixture +from tests.mocks.mock import MockRequestsResponse from tests.mocks.opds_for_distributors import MockOPDSForDistributorsAPI @@ -728,7 +729,9 @@ class MockOPDSForDistributorsReaperMonitor(OPDSForDistributorsReaperMonitor): """An OPDSForDistributorsReaperMonitor that overrides _get.""" def _get(self, url, headers): - return (200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, feed) + return MockRequestsResponse( + 200, {"content-type": OPDSFeed.ACQUISITION_FEED_TYPE}, feed + ) data_source = DataSource.lookup( opds_dist_api_fixture.db.session, "Biblioboard", autocreate=True diff --git a/tests/manager/api/test_overdrive.py b/tests/manager/api/test_overdrive.py index a74a8be5b5..e378691ceb 100644 --- a/tests/manager/api/test_overdrive.py +++ b/tests/manager/api/test_overdrive.py @@ -76,7 +76,7 @@ from tests.fixtures.files import FilesFixture from tests.fixtures.library import LibraryFixture from tests.fixtures.webserver import MockAPIServer, MockAPIServerResponse -from tests.mocks.mock import DummyHTTPClient, MockRequestsResponse +from tests.mocks.mock import MockHTTPClient, MockRequestsResponse from tests.mocks.overdrive import MockOverdriveAPI if TYPE_CHECKING: @@ -1155,7 +1155,7 @@ def process_error_response(message): # Same if the error message is missing or the response can't be # processed. pytest.raises(CannotHold, process_error_response, dict()) - pytest.raises(CannotHold, process_error_response, None) + pytest.raises(CannotHold, process_error_response, json.dumps(None)) # Same if the error code isn't in the 4xx or 2xx range # (which shouldn't happen in real life). @@ -1320,18 +1320,17 @@ def _extract_early_return_url(self, *args): # The first will be to the fulfill link returned by our mock # get_fulfillment_link. The response to this request is a # redirect that includes an early return link. - http = DummyHTTPClient() - http.responses.append( - MockRequestsResponse( - 302, dict(location="http://fulfill-this-book/?or=return-early") - ) + http = MockHTTPClient() + http.queue_response( + 302, + other_headers=dict(location="http://fulfill-this-book/?or=return-early"), ) # The second HTTP request made will be to the early return # link 'extracted' from that link by our mock # _extract_early_return_url. The response here is a copy of # the actual response Overdrive sends in this situation. - http.responses.append(MockRequestsResponse(200, content="Success")) + http.queue_response(200, content="Success") # Do the thing. success = overdrive.perform_early_return(patron, pin, loan, http.do_get) @@ -1365,9 +1364,10 @@ def _extract_early_return_url(self, *args): # overdrive._extract_early_return_url_call = None overdrive.EARLY_RETURN_URL = None # type: ignore - http.responses.append( - MockRequestsResponse(302, dict(location="http://fulfill-this-book/")) + http.queue_response( + 302, other_headers=dict(location="http://fulfill-this-book/") ) + success = overdrive.perform_early_return(patron, pin, loan, http.do_get) assert False == success @@ -1393,12 +1393,11 @@ def _extract_early_return_url(self, *args): # If the final attempt to hit the return URL doesn't result # in a 200 status code, perform_early_return has no effect. - http.responses.append( - MockRequestsResponse( - 302, dict(location="http://fulfill-this-book/?or=return-early") - ) + http.queue_response( + 302, + other_headers=dict(location="http://fulfill-this-book/?or=return-early"), ) - http.responses.append(MockRequestsResponse(401, content="Unauthorized!")) + http.queue_response(401, content="Unauthorized!") success = overdrive.perform_early_return(patron, pin, loan, http.do_get) assert False == success diff --git a/tests/manager/core/test_opds_import.py b/tests/manager/core/test_opds_import.py index abe68bbc9e..c87133eb5c 100644 --- a/tests/manager/core/test_opds_import.py +++ b/tests/manager/core/test_opds_import.py @@ -59,7 +59,7 @@ from palace.manager.util.opds_writer import AtomFeed, OPDSFeed, OPDSMessage from tests.fixtures.database import DatabaseTransactionFixture from tests.fixtures.files import OPDSFilesFixture -from tests.mocks.mock import DummyHTTPClient +from tests.mocks.mock import MockHTTPClient, MockRequestsResponse class DoomedOPDSImporter(OPDSImporter): @@ -118,25 +118,6 @@ def opds_importer_fixture( class TestOPDSImporter: - def test_constructor(self, opds_importer_fixture: OPDSImporterFixture): - data, db, session = ( - opds_importer_fixture, - opds_importer_fixture.db, - opds_importer_fixture.db.session, - ) - - # The default way of making HTTP requests is with - # Representation.cautious_http_get. - importer = opds_importer_fixture.importer() - assert Representation.cautious_http_get == importer.http_get - - # But you can pass in anything you want. - do_get = MagicMock() - importer = OPDSImporter( - session, collection=db.default_collection(), http_get=do_get - ) - assert do_get == importer.http_get - def test_data_source_autocreated(self, opds_importer_fixture: OPDSImporterFixture): data, db, session = ( opds_importer_fixture, @@ -1725,7 +1706,9 @@ def test_feed_contains_new_data( class MockOPDSImportMonitor(OPDSImportMonitor): def _get(self, url, headers): - return 200, {"content-type": AtomFeed.ATOM_TYPE}, feed + return MockRequestsResponse( + 200, {"content-type": AtomFeed.ATOM_TYPE}, feed + ) data_source_name = "OPDS" collection = db.collection( @@ -1825,7 +1808,7 @@ def test_follow_one_link(self, opds_importer_fixture: OPDSImporterFixture): ) feed = data.content_server_mini_feed - http = DummyHTTPClient() + http = MockHTTPClient() # If there's new data, follow_one_link extracts the next links. def follow(): diff --git a/tests/manager/sqlalchemy/model/test_resource.py b/tests/manager/sqlalchemy/model/test_resource.py index 3d5e26e935..380dd1bfd9 100644 --- a/tests/manager/sqlalchemy/model/test_resource.py +++ b/tests/manager/sqlalchemy/model/test_resource.py @@ -13,7 +13,7 @@ TemporaryDirectoryConfigurationFixture, ) from tests.fixtures.files import SampleCoversFixture -from tests.mocks.mock import DummyHTTPClient, MockRequestsResponse +from tests.mocks.mock import MockRepresentationHTTPClient, MockRequestsResponse class TestHyperlink: @@ -289,7 +289,7 @@ def test_unicode_content_is_none_when_decoding_is_impossible( assert None == representation.unicode_content def test_presumed_media_type(self, db: DatabaseTransactionFixture): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() # In the absence of a content-type header, the presumed_media_type # takes over. @@ -328,7 +328,7 @@ def test_presumed_media_type(self, db: DatabaseTransactionFixture): assert "text/plain" == representation.media_type def test_404_creates_cachable_representation(self, db: DatabaseTransactionFixture): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(404) url = db.fresh_url() @@ -340,7 +340,7 @@ def test_404_creates_cachable_representation(self, db: DatabaseTransactionFixtur assert representation == representation2 def test_302_creates_cachable_representation(self, db: DatabaseTransactionFixture): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(302) url = db.fresh_url() @@ -354,7 +354,7 @@ def test_302_creates_cachable_representation(self, db: DatabaseTransactionFixtur def test_500_creates_uncachable_representation( self, db: DatabaseTransactionFixture ): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(500) url = db.fresh_url() representation, cached = Representation.get(db.session, url, do_get=h.do_get) @@ -367,7 +367,7 @@ def test_500_creates_uncachable_representation( def test_response_reviewer_impacts_representation( self, db: DatabaseTransactionFixture ): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(200, media_type="text/html") def reviewer(response): @@ -496,7 +496,7 @@ def test_default_filename(self, db: DatabaseTransactionFixture): assert "cover.png" == filename def test_cautious_http_get(self): - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(200, content="yay") # If the domain is obviously safe, the GET request goes through, @@ -660,7 +660,7 @@ def normalize(self, url): normalizer = Normalizer() - h = DummyHTTPClient() + h = MockRepresentationHTTPClient() h.queue_response(200, content="yay") original_url = "http://url/?sid=12345" diff --git a/tests/mocks/mock.py b/tests/mocks/mock.py index 00a42a8562..239cdb6ec4 100644 --- a/tests/mocks/mock.py +++ b/tests/mocks/mock.py @@ -1,5 +1,8 @@ import json import logging +from typing import Any + +from requests import Request, Response from palace.manager.core.coverage import ( BibliographicCoverageProvider, @@ -9,6 +12,7 @@ ) from palace.manager.core.opds_import import OPDSAPI from palace.manager.sqlalchemy.model.datasource import DataSource +from palace.manager.sqlalchemy.model.resource import HttpResponseTuple def _normalize_level(level): @@ -212,14 +216,42 @@ def process_batch(self, batch): return [] -class DummyHTTPClient: - def __init__(self): - self.responses = [] - self.requests = [] +class MockHTTPClient: + def __init__(self) -> None: + self.responses: list[Response] = [] + self.requests: list[str] = [] def queue_response( - self, response_code, media_type="text/html", other_headers=None, content="" + self, + response_code: int, + media_type: str | None = None, + other_headers: dict[str, str] | None = None, + content: str | bytes = "", ): + """Queue a response of the type produced by HTTP.get_with_timeout.""" + headers = dict(other_headers or {}) + if media_type: + headers["Content-Type"] = media_type + + self.responses.append(MockRequestsResponse(response_code, headers, content)) + + def do_get(self, url: str, *args: Any, **kwargs: Any) -> Response: + self.requests.append(url) + return self.responses.pop(0) + + +class MockRepresentationHTTPClient: + def __init__(self) -> None: + self.responses: list[HttpResponseTuple] = [] + self.requests: list[str | tuple[str, str]] = [] + + def queue_response( + self, + response_code: int, + media_type: str | None = "text/html", + other_headers: dict[str, str] | None = None, + content: str | bytes = "", + ) -> None: """Queue a response of the type produced by Representation.simple_http_get. """ @@ -235,21 +267,13 @@ def queue_response( headers[k.lower()] = v self.responses.append((response_code, headers, content)) - def queue_requests_response( - self, response_code, media_type="text/html", other_headers=None, content="" - ): - """Queue a response of the type produced by HTTP.get_with_timeout.""" - headers = dict(other_headers or {}) - if media_type: - headers["Content-Type"] = media_type - response = MockRequestsResponse(response_code, headers, content) - self.responses.append(response) - - def do_get(self, url, *args, **kwargs): + def do_get(self, url: str, *args: Any, **kwargs: Any) -> HttpResponseTuple: self.requests.append(url) return self.responses.pop(0) - def do_post(self, url, data, *wargs, **kwargs): + def do_post( + self, url: str, data: str, *wargs: Any, **kwargs: Any + ) -> HttpResponseTuple: self.requests.append((url, data)) return self.responses.pop(0) @@ -265,41 +289,40 @@ def __init__(self, url, method="GET", headers=None): self.headers = headers or dict() -class MockRequestsResponse: +class MockRequestsResponse(Response): """A mock object that simulates an HTTP response from the `requests` library. """ - def __init__(self, status_code, headers={}, content=None, url=None, request=None): + def __init__( + self, + status_code: int, + headers: dict[str, str] | None = None, + content: Any = None, + url: str | None = None, + request: Request | None = None, + ): + super().__init__() + self.status_code = status_code - self.headers = headers + if headers is not None: + for k, v in headers.items(): + self.headers[k] = v + # We want to enforce that the mocked content is a bytestring # just like a real response. - if content and isinstance(content, str): - self.content = content.encode("utf-8") - else: - self.content = content + if content is not None: + if isinstance(content, str): + content_bytes = content.encode("utf-8") + elif isinstance(content, bytes): + content_bytes = content + else: + content_bytes = json.dumps(content).encode("utf-8") + self._content = content_bytes + if request and not url: url = request.url self.url = url or "http://url/" self.encoding = "utf-8" - self.request = request - - def json(self): - content = self.content - # The queued content might be a JSON string or it might - # just be the object you'd get from loading a JSON string. - if isinstance(content, (str, bytes)): - content = json.loads(self.content) - return content - - @property - def text(self): - if isinstance(self.content, bytes): - return self.content.decode("utf8") - return self.content - - def raise_for_status(self): - """Null implementation of raise_for_status, a method - implemented by real requests Response objects. - """ + if request: + self.request = request.prepare()