Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the fixture for creating collections #2101

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/palace/manager/api/enki.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ class EnkiAPI(
HasCollectionSelfTests,
EnkiConstants,
):
ENKI_LIBRARY_ID_KEY = "enki_library_id"
DESCRIPTION = _("Integrate an Enki collection.")

list_endpoint = "ListAPI"
Expand Down
7 changes: 5 additions & 2 deletions src/palace/manager/sqlalchemy/model/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,12 @@ def explain(self, include_secrets: bool = False) -> list[str]:
def process_settings_dict(
settings_dict: dict[str, Any], indent: int = 0
) -> None:
secret_keys = ["key", "password", "token"]
secret_keys = ["key", "password", "token", "secret"]
for setting_key, setting_value in sorted(settings_dict.items()):
if setting_key in secret_keys and not include_secrets:
if (
any(secret_key in setting_key for secret_key in secret_keys)
and not include_secrets
):
setting_value = "********"
lines.append(" " * indent + f"{setting_key}: {setting_value}")

Expand Down
127 changes: 81 additions & 46 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import importlib
import logging
import shutil
Expand All @@ -20,22 +21,32 @@
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy.engine import Connection, Engine, Transaction, make_url
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm.attributes import flag_modified
from typing_extensions import Self

from palace.manager.api.authentication.base import AuthenticationProvider
from palace.manager.api.authentication.base import SettingsType as TAuthProviderSettings
from palace.manager.api.circulation import (
BaseCirculationAPI,
BaseCirculationApiSettings,
)
from palace.manager.api.circulation import SettingsType as TCirculationSettings
from palace.manager.api.discovery.opds_registration import (
OpdsRegistrationService,
OpdsRegistrationServiceSettings,
)
from palace.manager.api.odl.api import OPDS2WithODLApi
from palace.manager.api.odl.settings import OPDS2WithODLSettings
from palace.manager.api.overdrive import OverdriveAPI, OverdriveSettings
from palace.manager.api.simple_authentication import (
SimpleAuthenticationProvider,
SimpleAuthSettings,
)
from palace.manager.core.classifier import Classifier
from palace.manager.core.config import Configuration
from palace.manager.core.exceptions import BasePalaceException, PalaceValueError
from palace.manager.core.opds_import import OPDSAPI
from palace.manager.core.opds2_import import OPDS2API
from palace.manager.core.opds_import import OPDSAPI, OPDSImporterSettings
from palace.manager.integration.base import (
HasIntegrationConfiguration,
HasLibraryIntegrationConfiguration,
Expand Down Expand Up @@ -408,9 +419,8 @@ def _make_default_library(self) -> Library:
library = self.library("default", "default")
collection = self.collection(
"Default Collection",
protocol=OPDSAPI.label(),
data_source_name="OPDS",
external_account_id="http://opds.example.com/feed",
protocol=OPDSAPI,
settings=self.opds_settings(data_source="OPDS"),
)
collection.libraries.append(library)
return library
Expand Down Expand Up @@ -530,34 +540,73 @@ def library(
)
return library

opds_settings = functools.partial(
OPDSImporterSettings,
external_account_id="http://opds.example.com/feed",
data_source="OPDS",
)

overdrive_settings = functools.partial(
OverdriveSettings,
external_account_id="library_id",
overdrive_website_id="website_id",
overdrive_client_key="client_key",
overdrive_client_secret="client_secret",
overdrive_server_nickname="production",
)

opds2_odl_settings = functools.partial(
OPDS2WithODLSettings,
username="username",
password="password",
external_account_id="http://example.com/feed",
data_source=DataSource.FEEDBOOKS,
)

def collection_settings(
self, protocol: type[BaseCirculationAPI[TCirculationSettings, Any]]
) -> TCirculationSettings | None:
if protocol in [OPDSAPI, OPDS2API]:
return self.opds_settings() # type: ignore[return-value]
elif protocol == OverdriveAPI:
return self.overdrive_settings() # type: ignore[return-value]
elif protocol == OPDS2WithODLApi:
return self.opds2_odl_settings() # type: ignore[return-value]
return None

def collection(
self,
name=None,
protocol=OPDSAPI.label(),
external_account_id=None,
url=None,
username=None,
password=None,
data_source_name=None,
settings: dict[str, Any] | None = None,
name: str | None = None,
*,
protocol: type[BaseCirculationAPI[Any, Any]] | str = OPDSAPI,
settings: BaseCirculationApiSettings | dict[str, Any] | None = None,
library: Library | None = None,
) -> Collection:
name = name or self.fresh_str()
collection, _ = Collection.by_name_and_protocol(self.session, name, protocol)
settings = settings or {}
if url:
settings["url"] = url
if username:
settings["username"] = username
if password:
settings["password"] = password
if external_account_id:
settings["external_account_id"] = external_account_id
collection.integration_configuration.settings_dict = settings

if data_source_name:
collection.data_source = data_source_name
if library:
protocol_str = (
protocol
if isinstance(protocol, str)
else self._goal_registry_mapping[Goals.LICENSE_GOAL].get_protocol(protocol)
)
assert protocol_str is not None
collection, _ = Collection.by_name_and_protocol(
self.session, name, protocol_str
)

if settings is None and not isinstance(protocol, str):
settings = self.collection_settings(protocol)

if isinstance(settings, BaseCirculationApiSettings):
if isinstance(protocol, str):
raise PalaceValueError(
"protocol must be a subclass of BaseCirculationAPI to set settings"
)
protocol.settings_update(collection.integration_configuration, settings)
elif isinstance(settings, dict):
collection.integration_configuration.settings_dict = settings
flag_modified(collection.integration_configuration, "settings_dict")

if library and library not in collection.libraries:
collection.libraries.append(library)
return collection

Expand Down Expand Up @@ -933,6 +982,11 @@ def _goal_registry_mapping(self) -> Mapping[Goals, IntegrationRegistry[Any]]:
Goals.PATRON_AUTH_GOAL: self._services.services.integration_registry.patron_auth(),
}

def protocol_string(
self, goal: Goals, protocol: type[BaseCirculationAPI[Any, Any]]
) -> str:
return self._goal_registry_mapping[goal].get_protocol(protocol, False)

def integration_configuration(
self,
protocol: type[HasIntegrationConfiguration[TIntegrationSettings]] | str,
Expand Down Expand Up @@ -1046,25 +1100,6 @@ def simple_auth_integration(
),
)

@classmethod
def set_settings(
cls,
config: IntegrationConfiguration | IntegrationLibraryConfiguration,
*keyvalues,
**kwargs,
):
settings = config.settings_dict.copy()

# Alternating key: value in the args
for ix, item in enumerate(keyvalues):
if ix % 2 == 0:
key = item
else:
settings[key] = item

settings.update(kwargs)
config.settings_dict = settings

def work_coverage_record(
self, work, operation=None, status=CoverageRecord.SUCCESS
) -> WorkCoverageRecord:
Expand Down
26 changes: 9 additions & 17 deletions tests/fixtures/odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,18 @@ def __init__(
def create_work(self, collection: Collection) -> Work:
return self.db.work(with_license_pool=True, collection=collection)

@staticmethod
def default_collection_settings() -> dict[str, Any]:
return {
"username": "a",
"password": "b",
"external_account_id": "http://odl",
Collection.DATA_SOURCE_NAME_SETTING: "Feedbooks",
}

def create_collection(self, library: Library) -> Collection:
collection, _ = Collection.by_name_and_protocol(
self.db.session,
return self.db.collection(
f"Test {OPDS2WithODLApi.__name__} Collection",
OPDS2WithODLApi.label(),
)
collection.integration_configuration.settings_dict = (
self.default_collection_settings()
protocol=OPDS2WithODLApi,
library=library,
settings=self.db.opds2_odl_settings(
username="a",
password="b",
external_account_id="http://odl",
data_source="Feedbooks",
),
)
collection.libraries.append(library)
return collection

def setup_license(
self,
Expand Down
Loading