diff --git a/api/app.py b/api/app.py
index a15c997bb3..d380beed10 100644
--- a/api/app.py
+++ b/api/app.py
@@ -17,7 +17,7 @@
     SessionManager,
     pg_advisory_lock,
 )
-from core.service.container import create_container
+from core.service.container import container_instance
 from core.util import LanguageCodes
 from core.util.cache import CachedData
 from scripts import InstanceInitializationScript
@@ -73,7 +73,7 @@ def initialize_circulation_manager():
         pass
     else:
         if getattr(app, "manager", None) is None:
-            container = create_container()
+            container = container_instance()
             try:
                 app.manager = CirculationManager(app._db, container)
             except Exception:
diff --git a/api/google_analytics_provider.py b/api/google_analytics_provider.py
index eebf0a3a5d..09c82a72c8 100644
--- a/api/google_analytics_provider.py
+++ b/api/google_analytics_provider.py
@@ -6,6 +6,7 @@
 from flask_babel import lazy_gettext as _
 
 from core.model import ConfigurationSetting, ExternalIntegration, Session
+from core.service.container import Services
 from core.util.http import HTTP
 
 from .config import CannotLoadConfiguration
@@ -63,7 +64,7 @@ class GoogleAnalyticsProvider:
         {"key": TRACKING_ID, "label": _("Tracking ID"), "required": True},
     ]
 
-    def __init__(self, integration, library=None):
+    def __init__(self, integration, services: Services, library=None):
         _db = Session.object_session(integration)
         if not library:
             raise CannotLoadConfiguration(
diff --git a/api/s3_analytics_provider.py b/api/s3_analytics_provider.py
index 5a29b04089..8ddff535cc 100644
--- a/api/s3_analytics_provider.py
+++ b/api/s3_analytics_provider.py
@@ -4,7 +4,6 @@
 import string
 from typing import Dict, Optional
 
-from dependency_injector.wiring import Provide, inject
 from flask_babel import lazy_gettext as _
 from sqlalchemy.orm import Session
 
@@ -12,7 +11,7 @@
 from core.local_analytics_provider import LocalAnalyticsProvider
 from core.model import Library, LicensePool, MediaTypes
 from core.model.configuration import ConfigurationGrouping
-from core.service.storage.container import Storage
+from core.service.container import Services
 from core.service.storage.s3 import S3Service
 
 
@@ -30,15 +29,13 @@ class S3AnalyticsProvider(LocalAnalyticsProvider):
         LocalAnalyticsProvider.SETTINGS + S3AnalyticsProviderConfiguration.to_settings()
     )
 
-    @inject
     def __init__(
         self,
         integration,
+        services: Services,
         library=None,
-        s3_service: Optional[S3Service] = Provide[Storage.analytics],
     ):
-        self.s3_service = s3_service
-        super().__init__(integration, library)
+        super().__init__(integration, services, library)
 
     @staticmethod
     def _create_event_object(
@@ -252,12 +249,13 @@ def _get_storage(self) -> S3Service:
 
         :return: StorageServiceBase object
         """
-        if not self.s3_service:
+        s3_storage_service = self.services.storage.analytics()
+        if s3_storage_service is None:
             raise CannotLoadConfiguration(
                 "No storage service is configured with an analytics bucket."
             )
 
-        return self.s3_service
+        return s3_storage_service
 
 
 Provider = S3AnalyticsProvider
diff --git a/bin/odl2_import_monitor b/bin/odl2_import_monitor
index 04048fdd77..ca1871f437 100755
--- a/bin/odl2_import_monitor
+++ b/bin/odl2_import_monitor
@@ -11,8 +11,6 @@ sys.path.append(os.path.abspath(package_dir))
 
 from webpub_manifest_parser.odl import ODLFeedParserFactory
 
 from api.odl2 import ODL2Importer, ODL2ImportMonitor
-
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.opds2_import import RWPMManifestParser
 from core.scripts import RunCollectionMonitorScript
diff --git a/bin/odl2_schema_validate b/bin/odl2_schema_validate
index b51f1732cd..f7972efc14 100755
--- a/bin/odl2_schema_validate
+++ b/bin/odl2_schema_validate
@@ -11,8 +11,6 @@ sys.path.append(os.path.abspath(package_dir))
 
 from webpub_manifest_parser.odl import ODLFeedParserFactory
 from api.odl2 import ODL2Importer
-
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.opds2_import import RWPMManifestParser
 from core.opds_schema import ODL2SchemaValidation
 from core.scripts import RunCollectionMonitorScript
diff --git a/bin/odl_import_monitor b/bin/odl_import_monitor
index 57bbc86c1c..aa1b5cd332 100755
--- a/bin/odl_import_monitor
+++ b/bin/odl_import_monitor
@@ -9,8 +9,6 @@ package_dir = os.path.join(bin_dir, "..")
 sys.path.append(os.path.abspath(package_dir))
 
 from api.odl import ODLImporter, ODLImportMonitor
-
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.scripts import RunCollectionMonitorScript
 
 RunCollectionMonitorScript(
diff --git a/bin/opds2_import_monitor b/bin/opds2_import_monitor
index fd55eb5af9..3223ba6cd0 100755
--- a/bin/opds2_import_monitor
+++ b/bin/opds2_import_monitor
@@ -9,7 +9,6 @@ sys.path.append(os.path.abspath(package_dir))
 
 from webpub_manifest_parser.opds2 import OPDS2FeedParserFactory
 
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.model import ExternalIntegration
 from core.opds2_import import OPDS2Importer, OPDS2ImportMonitor, RWPMManifestParser
 from core.scripts import OPDSImportScript
diff --git a/bin/opds2_schema_validate b/bin/opds2_schema_validate
index 6a3977b5cc..070507d428 100755
--- a/bin/opds2_schema_validate
+++ b/bin/opds2_schema_validate
@@ -10,7 +10,6 @@ sys.path.append(os.path.abspath(package_dir))
 
 from webpub_manifest_parser.opds2 import OPDS2FeedParserFactory
 
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.model.configuration import ExternalIntegration
 from core.opds2_import import OPDS2Importer, RWPMManifestParser
 from core.opds_schema import OPDS2SchemaValidation
diff --git a/bin/opds_import_monitor b/bin/opds_import_monitor
index 6d2ed67297..b18022933c 100755
--- a/bin/opds_import_monitor
+++ b/bin/opds_import_monitor
@@ -8,7 +8,6 @@ bin_dir = os.path.split(__file__)[0]
 package_dir = os.path.join(bin_dir, "..")
 sys.path.append(os.path.abspath(package_dir))
 
-# NOTE: We need to import it explicitly to initialize MirrorUploader.IMPLEMENTATION_REGISTRY
 from core.scripts import OPDSImportScript
 
 OPDSImportScript().run()
diff --git a/core/analytics.py b/core/analytics.py
index d8b4f94391..4c1d9e27f3 100644
--- a/core/analytics.py
+++ b/core/analytics.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 import logging
 from collections import defaultdict
@@ -7,6 +9,7 @@
 
 from .config import CannotLoadConfiguration
 from .model import ExternalIntegration
+from .service.container import container_instance
 from .util.datetime_helpers import utc_now
 from .util.log import log_elapsed_time
 
@@ -26,7 +29,7 @@ class Analytics:
     GLOBAL_ENABLED: Optional[bool] = None
     LIBRARY_ENABLED: Set[int] = set()
 
-    def __new__(cls, _db, refresh=False) -> "Analytics":
+    def __new__(cls, _db: Session, refresh: bool = False) -> Analytics:
         instance = cls._singleton_instance
         if instance is None:
             refresh = True
@@ -44,15 +47,16 @@ def _reset_singleton_instance(cls):
         cls._singleton_instance = None
 
     @log_elapsed_time(log_method=log.debug, message_prefix="Initializing instance")
-    def _initialize_instance(self, _db):
+    def _initialize_instance(self, _db: Session) -> None:
         """Initialize an instance (usually the singleton) of the class.
 
         We don't use __init__ because it would be run whether or not
         a new instance were instantiated.
         """
+        services = container_instance()
         sitewide_providers = []
         library_providers = defaultdict(list)
-        initialization_exceptions: Dict[int, Exception] = {}
+        initialization_exceptions: Dict[int, Exception | str] = {}
         global_enabled = False
         library_enabled = set()
         # Find a list of all the ExternalIntegrations set up with a
@@ -68,12 +72,12 @@ def _initialize_instance(self, _db):
             provider_class = self._provider_class_from_module(module)
             if provider_class:
                 if not libraries:
-                    provider = provider_class(integration)
+                    provider = provider_class(integration, services)
                     sitewide_providers.append(provider)
                     global_enabled = True
                 else:
                     for library in libraries:
-                        provider = provider_class(integration, library)
+                        provider = provider_class(integration, services, library)
                         library_providers[library.id].append(provider)
                         library_enabled.add(library.id)
             else:
diff --git a/core/local_analytics_provider.py b/core/local_analytics_provider.py
index 207972591c..4b40c497bf 100644
--- a/core/local_analytics_provider.py
+++ b/core/local_analytics_provider.py
@@ -2,6 +2,7 @@
 from sqlalchemy.orm.session import Session
 
 from .model import CirculationEvent, ExternalIntegration, create, get_one
+from .service.container import Services
 
 
 class LocalAnalyticsProvider:
@@ -41,12 +42,13 @@ class LocalAnalyticsProvider:
         },
     ]
 
-    def __init__(self, integration, library=None):
+    def __init__(self, integration, services: Services, library=None):
         self.integration_id = integration.id
         self.location_source = (
             integration.setting(self.LOCATION_SOURCE).value
             or self.LOCATION_SOURCE_DISABLED
         )
+        self.services = services
         if library:
             self.library_id = library.id
         else:
diff --git a/core/marc.py b/core/marc.py
index 2355a30455..3e4da0c2f1 100644
--- a/core/marc.py
+++ b/core/marc.py
@@ -743,6 +743,7 @@ def records(
             if not is_new:
                 cached.representation = representation
             cached.end_time = end_time
+            representation.set_as_mirrored(upload.url)
         else:
             representation.mirror_exception = str(upload.exception)
 
diff --git a/core/mock_analytics_provider.py b/core/mock_analytics_provider.py
index 8742b94063..294b8b7d74 100644
--- a/core/mock_analytics_provider.py
+++ b/core/mock_analytics_provider.py
@@ -1,7 +1,7 @@
 class MockAnalyticsProvider:
     """A mock analytics provider that keeps track of how many times it's called."""
 
-    def __init__(self, integration=None, library=None):
+    def __init__(self, integration=None, services=None, library=None):
         self.count = 0
         self.event = None
         if integration:
diff --git a/core/scripts.py b/core/scripts.py
index 5ad961af0a..fa6bd623bb 100644
--- a/core/scripts.py
+++ b/core/scripts.py
@@ -58,7 +58,7 @@
 from .monitor import CollectionMonitor, ReaperMonitor
 from .opds_import import OPDSImporter, OPDSImportMonitor
 from .overdrive import OverdriveCoreAPI
-from .service.container import Services, create_container
+from .service.container import Services, container_instance
 from .util import fast_query_count
 from .util.datetime_helpers import strptime_utc, utc_now
 from .util.personal_names import contributor_name_match_ratio, display_name_to_sort_name
@@ -125,7 +125,7 @@ def __init__(self, _db=None, services: Optional[Services] = None, *args, **kwarg
             self._session = _db
 
         if services is None:
-            services = create_container()
+            services = container_instance()
         self._services = services
 
     def run(self):
diff --git a/core/service/container.py b/core/service/container.py
index 23d5006d00..b204df6462 100644
--- a/core/service/container.py
+++ b/core/service/container.py
@@ -19,3 +19,19 @@ def create_container() -> Services:
     container = Services()
     container.config.from_dict({"storage": StorageConfiguration().dict()})
     return container
+
+
+_container_instance = None
+
+
+def container_instance() -> Services:
+    # Create a singleton container instance. I'd like this to be used sparingly,
+    # and eventually to go away, but there are places in the code that
+    # are currently difficult to refactor to pass the container into the
+    # constructor.
+    # If at all possible, please use the container that is stored in the CirculationManager
+    # or Scripts classes instead of using this function.
+    global _container_instance
+    if _container_instance is None:
+        _container_instance = create_container()
+    return _container_instance
diff --git a/core/service/storage/configuration.py b/core/service/storage/configuration.py
index 9d2292d5a1..1ff6f6b01d 100644
--- a/core/service/storage/configuration.py
+++ b/core/service/storage/configuration.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 import boto3
-from pydantic import HttpUrl, parse_obj_as, validator
+from pydantic import AnyHttpUrl, parse_obj_as, validator
 
 from core.service.configuration import ServiceConfiguration
 
@@ -14,10 +14,10 @@ class StorageConfiguration(ServiceConfiguration):
     public_access_bucket: Optional[str] = None
     analytics_bucket: Optional[str] = None
 
-    endpoint_url: Optional[HttpUrl] = None
+    endpoint_url: Optional[AnyHttpUrl] = None
 
-    url_template: HttpUrl = parse_obj_as(
-        HttpUrl, "https://{bucket}.s3.{region}.amazonaws.com/{key}"
+    url_template: AnyHttpUrl = parse_obj_as(
+        AnyHttpUrl, "https://{bucket}.s3.{region}.amazonaws.com/{key}"
     )
 
     @validator("region")
diff --git a/core/service/storage/container.py b/core/service/storage/container.py
index b8a1b004e9..54cf2db835 100644
--- a/core/service/storage/container.py
+++ b/core/service/storage/container.py
@@ -1,21 +1,11 @@
 import boto3
 from dependency_injector import providers
-from dependency_injector.containers import DeclarativeContainer, WiringConfiguration
+from dependency_injector.containers import DeclarativeContainer
 
 from core.service.storage.s3 import S3Service
 
 
 class Storage(DeclarativeContainer):
-
-    # See https://python-dependency-injector.ets-labs.org/wiring.html
-    # This lists modules that contain markers that will be used to wire up
-    # dependencies from this container automatically.
-    wiring_config = WiringConfiguration(
-        modules=[
-            "api.s3_analytics_provider",
-        ],
-    )
-
     config = providers.Configuration()
 
     s3_client = providers.Singleton(
diff --git a/scripts.py b/scripts.py
index 00689d45c5..5241ecbd16 100644
--- a/scripts.py
+++ b/scripts.py
@@ -6,6 +6,7 @@
 from datetime import timedelta
 from pathlib import Path
 from typing import Optional
+from unittest.mock import MagicMock
 
 from sqlalchemy import inspect
 from sqlalchemy.engine import Connection
@@ -193,7 +194,7 @@ def __init__(self, _db=None, cmd_args=None, manager=None, *args, **kwargs):
         super().__init__(_db, *args, **kwargs)
         self.parse_args(cmd_args)
         if not manager:
-            manager = CirculationManager(self._db)
+            manager = CirculationManager(self._db, MagicMock())
         from api.app import app
 
         app.manager = manager
diff --git a/tests/api/admin/test_routes.py b/tests/api/admin/test_routes.py
index 465aec0e7f..6dae466615 100644
--- a/tests/api/admin/test_routes.py
+++ b/tests/api/admin/test_routes.py
@@ -680,29 +680,6 @@ def test_process_search_service_self_tests(self, fixture: AdminRouteFixture):
         fixture.assert_supported_methods(url, "GET", "POST")
 
 
-class TestAdminStorageServices:
-    CONTROLLER_NAME = "admin_storage_services_controller"
-
-    @pytest.fixture(scope="function")
-    def fixture(self, admin_route_fixture: AdminRouteFixture) -> AdminRouteFixture:
-        admin_route_fixture.set_controller_name(self.CONTROLLER_NAME)
-        return admin_route_fixture
-
-    def test_process_services(self, fixture: AdminRouteFixture):
-        url = "/admin/storage_services"
-        fixture.assert_authenticated_request_calls(
-            url, fixture.controller.process_services  # type: ignore
-        )
-        fixture.assert_supported_methods(url, "GET", "POST")
-
-    def test_process_delete(self, fixture: AdminRouteFixture):
-        url = "/admin/storage_service/"
-        fixture.assert_authenticated_request_calls(
-            url, fixture.controller.process_delete, "", http_method="DELETE"  # type: ignore
-        )
-        fixture.assert_supported_methods(url, "DELETE")
-
-
 class TestAdminCatalogServices:
     CONTROLLER_NAME = "admin_catalog_services_controller"
 
diff --git a/tests/api/test_google_analytics_provider.py b/tests/api/test_google_analytics_provider.py
index ceee85fd6c..26682ceb9e 100644
--- a/tests/api/test_google_analytics_provider.py
+++ b/tests/api/test_google_analytics_provider.py
@@ -1,5 +1,6 @@
 import unicodedata
 import urllib.parse
+from unittest.mock import MagicMock
 
 import pytest
 from psycopg2.extras import NumericRange
@@ -37,13 +38,13 @@ def test_init(self, db: DatabaseTransactionFixture):
         )
 
         with pytest.raises(CannotLoadConfiguration) as excinfo:
-            GoogleAnalyticsProvider(integration)
+            GoogleAnalyticsProvider(integration, MagicMock())
         assert "Google Analytics can't be configured without a library." in str(
             excinfo.value
         )
         with pytest.raises(CannotLoadConfiguration) as excinfo:
-            GoogleAnalyticsProvider(integration, db.default_library())
+            GoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
         assert (
             "Missing tracking id for library %s" % db.default_library().short_name
             in str(excinfo.value)
         )
@@ -55,12 +56,12 @@
             db.default_library(),
             integration,
         ).value = "faketrackingid"
-        ga = GoogleAnalyticsProvider(integration, db.default_library())
+        ga = GoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
         assert GoogleAnalyticsProvider.DEFAULT_URL == ga.url
         assert "faketrackingid" == ga.tracking_id
 
         integration.url = db.fresh_str()
-        ga = GoogleAnalyticsProvider(integration, db.default_library())
+        ga = GoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
         assert integration.url == ga.url
         assert "faketrackingid" == ga.tracking_id
 
@@ -78,7 +79,7 @@ def test_collect_event_with_work(self, db: DatabaseTransactionFixture):
             db.default_library(),
             integration,
         ).value = "faketrackingid"
-        ga = MockGoogleAnalyticsProvider(integration, db.default_library())
+        ga = MockGoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
 
         work = db.work(
             title="pi\u00F1ata",
@@ -146,7 +147,7 @@ def test_collect_event_without_work(self, db: DatabaseTransactionFixture):
             db.default_library(),
             integration,
         ).value = "faketrackingid"
-        ga = MockGoogleAnalyticsProvider(integration, db.default_library())
+        ga = MockGoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
 
         identifier = db.identifier()
         source = DataSource.lookup(db.session, DataSource.GUTENBERG)
@@ -201,7 +202,7 @@ def test_collect_event_without_license_pool(self, db: DatabaseTransactionFixture
             db.default_library(),
             integration,
         ).value = "faketrackingid"
-        ga = MockGoogleAnalyticsProvider(integration, db.default_library())
+        ga = MockGoogleAnalyticsProvider(integration, MagicMock(), db.default_library())
 
         now = utc_now()
         ga.collect_event(db.default_library(), None, CirculationEvent.NEW_PATRON, now)
diff --git a/tests/core/service/__init__.py b/tests/core/service/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/service/storage/__init__.py b/tests/core/service/storage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/service/storage/test_configuration.py b/tests/core/service/storage/test_configuration.py
new file mode 100644
index 0000000000..2c4d0ecbc5
--- /dev/null
+++ b/tests/core/service/storage/test_configuration.py
@@ -0,0 +1,44 @@
+import pytest
+
+from core.config import CannotLoadConfiguration
+from core.service.storage.configuration import StorageConfiguration
+
+
+def test_region_validation_fail():
+    with pytest.raises(CannotLoadConfiguration) as exc_info:
+        StorageConfiguration(region="foo bar baz")
+
+    assert "PALACE_STORAGE_REGION: Invalid region: foo bar baz." in str(exc_info.value)
+
+
+def test_region_validation_success():
+    configuration = StorageConfiguration(region="us-west-2")
+    assert configuration.region == "us-west-2"
+
+
+@pytest.mark.parametrize(
+    "url",
+    [
+        "http://localhost:9000",
+        "https://real.endpoint.com",
+        "http://192.168.0.1",
+    ],
+)
+def test_endpoint_url_validation_success(url: str):
+    configuration = StorageConfiguration(endpoint_url=url)
+    assert configuration.endpoint_url == url
+
+
+@pytest.mark.parametrize(
+    "url, error",
+    [
+        ("ftp://localhost:9000", "URL scheme not permitted"),
+        ("foo bar baz", "invalid or missing URL scheme"),
+    ],
+)
+def test_endpoint_url_validation_fail(url: str, error: str):
+    with pytest.raises(CannotLoadConfiguration) as exc_info:
+        StorageConfiguration(endpoint_url=url)
+
+    assert "PALACE_STORAGE_ENDPOINT_URL" in str(exc_info.value)
+    assert error in str(exc_info.value)
diff --git a/tests/core/test_local_analytics_provider.py b/tests/core/test_local_analytics_provider.py
index 031b9a90d0..478e4d6d68 100644
--- a/tests/core/test_local_analytics_provider.py
+++ b/tests/core/test_local_analytics_provider.py
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import pytest
 
 from core.local_analytics_provider import LocalAnalyticsProvider
 from core.model import CirculationEvent, ExternalIntegration, create, get_one
 from core.util.datetime_helpers import utc_now
-from tests.fixtures.database import DatabaseTransactionFixture
+
+if TYPE_CHECKING:
+    from tests.fixtures.database import DatabaseTransactionFixture
+    from tests.fixtures.services import MockServicesFixture
 
 
 class TestInitializeLocalAnalyticsProvider:
@@ -49,7 +56,11 @@ class LocalAnalyticsProviderFixture:
     integration: ExternalIntegration
     la: LocalAnalyticsProvider
 
-    def __init__(self, transaction: DatabaseTransactionFixture):
+    def __init__(
+        self,
+        transaction: DatabaseTransactionFixture,
+        mock_services_fixture: MockServicesFixture,
+    ):
         self.transaction = transaction
         self.integration, ignore = create(
             transaction.session,
@@ -57,16 +68,17 @@ def __init__(self, transaction: DatabaseTransactionFixture):
             goal=ExternalIntegration.ANALYTICS_GOAL,
             protocol="core.local_analytics_provider",
         )
+        self.services = mock_services_fixture.services
         self.la = LocalAnalyticsProvider(
-            self.integration, transaction.default_library()
+            self.integration, self.services, transaction.default_library()
         )
 
 
 @pytest.fixture()
 def local_analytics_provider_fixture(
-    db,
+    db: DatabaseTransactionFixture, mock_services_fixture: MockServicesFixture
 ) -> LocalAnalyticsProviderFixture:
-    return LocalAnalyticsProviderFixture(db)
+    return LocalAnalyticsProviderFixture(db, mock_services_fixture)
 
 
 class TestLocalAnalyticsProvider:
@@ -123,7 +135,7 @@ def test_collect_event(
 
         # It's possible to instantiate the LocalAnalyticsProvider
        # without a library.
-        la = LocalAnalyticsProvider(data.integration)
+        la = LocalAnalyticsProvider(data.integration, data.services)
 
         # In that case, it will process events for any library.
         for library in [database.default_library(), library2]:
@@ -184,7 +196,7 @@ def test_neighborhood_is_location(
         data.integration.setting(
             p.LOCATION_SOURCE
         ).value = p.LOCATION_SOURCE_NEIGHBORHOOD
-        la = p(data.integration, database.default_library())
+        la = p(data.integration, data.services, database.default_library())
 
         event, is_new = la.collect_event(
             database.default_library(),
diff --git a/tests/core/test_marc.py b/tests/core/test_marc.py
index 699673a7f9..1c4c82f210 100644
--- a/tests/core/test_marc.py
+++ b/tests/core/test_marc.py
@@ -2,6 +2,7 @@
 
 import datetime
 from typing import TYPE_CHECKING
+from urllib.parse import quote
 
 import pytest
 from freezegun import freeze_time
@@ -655,6 +656,7 @@ def test_records_lane(
             str(cache.representation.fetched_at),
             lane_or_wl.display_name,
         )
+        assert quote(storage_service.uploads[0].key) in cache.representation.mirror_url
         assert cache.start_time is None
         assert marc_exporter_fixture.now < cache.end_time
 
diff --git a/tests/core/test_s3_analytics_provider.py b/tests/core/test_s3_analytics_provider.py
index a57f37dab8..9ed40e782a 100644
--- a/tests/core/test_s3_analytics_provider.py
+++ b/tests/core/test_s3_analytics_provider.py
@@ -37,7 +37,7 @@ def __init__(
         self.services = services_fixture.services
         self.analytics_storage = services_fixture.storage.analytics
         self.analytics_provider = S3AnalyticsProvider(
-            self.analytics_integration, db.default_library()
+            self.analytics_integration, self.services, db.default_library()
         )
 
 
@@ -65,11 +65,13 @@ def test_exception_is_raised_when_no_analytics_bucket_configured(
         self, s3_analytics_fixture: S3AnalyticsFixture
     ):
         # The services container returns None when there is no analytics storage service configured,
-        # we simulate this by directly passing None to the analytics provider.
+        # so we override the analytics storage service with None to simulate this situation.
+        s3_analytics_fixture.services.storage.analytics.override(None)
+
         provider = S3AnalyticsProvider(
             s3_analytics_fixture.analytics_integration,
+            s3_analytics_fixture.services,
             s3_analytics_fixture.db.default_library(),
-            s3_service=None,
         )
 
         # Act, Assert
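
Usage sketch (not part of the diff above): a minimal illustration of how the pieces introduced by this change fit together, based only on code visible in the hunks. The integration and library objects named in the comments are placeholders a real caller would supply; nothing here should be read as the project's documented API beyond what the diff itself shows.

from core.service.container import container_instance

# container_instance() lazily builds a single Services container on first use
# and returns that same object on every later call; create_container() still
# builds a fresh container each time it is invoked.
services = container_instance()
assert services is container_instance()

# Analytics providers now receive the Services container explicitly instead of
# relying on dependency_injector wiring, and resolve storage lazily at use time:
#
#   provider = S3AnalyticsProvider(integration, services, library)
#   provider.services.storage.analytics()  # None when no analytics bucket is configured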