From 37d72b4c74c6665c06aaab31a03554ba89f8e27a Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Thu, 2 May 2024 10:58:10 -0300 Subject: [PATCH] Fix mypy and add better comments. --- src/palace/manager/scripts.py | 9 +- tests/fixtures/database.py | 229 ++++++++++++------ tests/fixtures/search.py | 4 +- .../api/controller/test_scopedsession.py | 28 ++- tests/manager/api/test_scripts.py | 24 +- tests/migration/conftest.py | 4 +- tests/migration/test_instance_init_script.py | 43 ++-- 7 files changed, 218 insertions(+), 123 deletions(-) diff --git a/src/palace/manager/scripts.py b/src/palace/manager/scripts.py index a4682d55c9..46fbd8ae0b 100644 --- a/src/palace/manager/scripts.py +++ b/src/palace/manager/scripts.py @@ -4,7 +4,7 @@ import os import sys import time -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import timedelta from pathlib import Path from typing import Any @@ -12,7 +12,7 @@ from alembic import command, config from alembic.util import CommandError from sqlalchemy import inspect, select -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session @@ -489,6 +489,7 @@ class InstanceInitializationScript: def __init__( self, config_file: Path | None = None, + engine_factory: Callable[[], Engine] = SessionManager.engine, ) -> None: self._log: logging.Logger | None = None self._container = container_instance() @@ -497,6 +498,8 @@ def __init__( self._container.init_resources() self._config_file = config_file + self._engine_factory = engine_factory + @property def log(self) -> logging.Logger: if self._log is None: @@ -571,7 +574,7 @@ def run(self) -> None: instance of the script is running at a time. This prevents multiple instances from trying to initialize the database at the same time. 
""" - engine = SessionManager.engine() + engine = self._engine_factory() with engine.begin() as connection: with pg_advisory_lock(connection, LOCK_ID_DB_INIT): self.initialize(connection) diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index a24dafcd32..a0eda98df5 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -18,7 +18,7 @@ from Crypto.PublicKey.RSA import import_key from pydantic import PostgresDsn from sqlalchemy import MetaData, create_engine, text -from sqlalchemy.engine import Connection, Transaction, make_url +from sqlalchemy.engine import Connection, Engine, Transaction, make_url from sqlalchemy.orm import Session, sessionmaker from typing_extensions import Self @@ -62,36 +62,59 @@ from palace.manager.sqlalchemy.model.patron import Patron from palace.manager.sqlalchemy.model.resource import Hyperlink, Representation from palace.manager.sqlalchemy.model.work import Work -from palace.manager.sqlalchemy.session import SessionManager +from palace.manager.sqlalchemy.session import SessionManager, json_serializer from palace.manager.sqlalchemy.util import create, get_one_or_create from palace.manager.util.datetime_helpers import utc_now class TestIdFixture: """ - This fixture creates a unique per test run worker id. Each test worker will get its own - ID, so its suitable for initializing shared resources that need to be unique per test worker. + This fixture creates a unique test id. This ID is suitable for initializing shared resources. For example - database name, opensearch index name, etc. """ def __init__(self, worker_id: str, prefix: str): - self.worker_id = worker_id + # worker_id comes from the pytest-xdist fixture + self._worker_id = worker_id + # This flag indicates that the tests are running in parallel mode. self.parallel = worker_id != "master" + # We create a unique run id for each test run. The dashes are # replaced with underscores to make it a valid identifier for # in PostgreSQL. 
self.run_id = str(uuid.uuid4()).replace("-", "_") - self.id = f"{prefix}_{self.worker_id}_{self.run_id}" + + self._prefix = prefix + + @cached_property + def id(self) -> str: + # The test id is a combination of the prefix, worker id and run id. + # This is the ID that should be used to create unique resources. + return f"{self._prefix}_{self._worker_id}_{self.run_id}" @pytest.fixture(scope="session") -def test_id(worker_id: str) -> TestIdFixture: +def session_test_id(worker_id: str) -> TestIdFixture: + """ + This is a session scoped fixture that provides a unique test id. Since session scoped fixtures + are created only once per worker, per test run, this fixture provides a unique test ID that is + stable for the worker for the entire test run. + + This is useful when initializing session scoped shared resources like databases, opensearch indexes, etc. + """ return TestIdFixture(worker_id, "session") @pytest.fixture(scope="function") -def test_id_func(worker_id: str) -> TestIdFixture: +def function_test_id(worker_id: str) -> TestIdFixture: + """ + This is a function scoped fixture that provides a unique test id. Since function scoped fixtures + are created for each test function, this fixture provides a unique test ID for each test function. + + This is useful when initializing function scoped shared resources. + """ + return TestIdFixture(worker_id, "function") @@ -103,7 +126,15 @@ class Config: env_prefix = "PALACE_TEST_DATABASE_" -class DatabaseNameFixture: +class DatabaseCreationFixture: + """ + Uses the configured database URL to create a unique database for each test run. The database + is dropped after the test run is complete. + + Database creation can be disabled by setting the `create_database` flag to False in the configuration. + In this case the database URL is used as is. 
+ """ + def __init__(self, test_id: TestIdFixture): self.test_id = test_id config = DatabaseTestConfiguration() @@ -113,22 +144,42 @@ def __init__(self, test_id: TestIdFixture): "This is not supported. Please enable database creation or run tests in serial mode." ) self.create_database = config.create_database - self.main_url = make_url(config.url) + self._config_url = make_url(config.url) @cached_property def database_name(self) -> str: + """ + Returns the name of the database that the test should use. + """ + if not self.create_database: - return self.main_url.database + if self._config_url.database is None: + raise BasePalaceException( + "Database name is required when database creation is disabled." + ) + return self._config_url.database return self.test_id.id @cached_property - def worker_url(self) -> str: - return str(self.main_url.set(database=self.database_name)) + def url(self) -> str: + """ + Returns the Postgres URL for the database that the test should use. This URL + includes credentials and the database name, so it has everything needed to + connect to the database. + """ + + return str(self._config_url.set(database=self.database_name)) @contextmanager def _db_connection(self) -> Generator[Connection, None, None]: - engine = create_engine(self.main_url, isolation_level="AUTOCOMMIT") + """ + Databases need to be created and dropped outside a transaction. This method + provides a connection to the database URL provided in the configuration that is not + wrapped in a transaction. 
+ """ + + engine = create_engine(self._config_url, isolation_level="AUTOCOMMIT") connection = engine.connect() try: yield connection @@ -136,18 +187,18 @@ def _db_connection(self) -> Generator[Connection, None, None]: connection.close() engine.dispose() - def create(self) -> None: + def _create_db(self) -> None: if not self.create_database: return with self._db_connection() as connection: - user = self.main_url.username + user = self._config_url.username connection.execute(text(f"CREATE DATABASE {self.database_name}")) connection.execute( text(f"GRANT ALL PRIVILEGES ON DATABASE {self.database_name} TO {user}") ) - def drop(self) -> None: + def _drop_db(self) -> None: if not self.create_database: return @@ -158,28 +209,41 @@ def drop(self) -> None: @contextmanager def fixture(cls, test_id: TestIdFixture) -> Generator[Self, None, None]: db_name_fixture = cls(test_id) - db_name_fixture.create() + db_name_fixture._create_db() try: # Patch the database URL, so any code that uses it will use the worker specific database. with patch.object( - Configuration, "database_url", return_value=db_name_fixture.worker_url + Configuration, "database_url", return_value=db_name_fixture.url ): yield db_name_fixture finally: - db_name_fixture.drop() + db_name_fixture._drop_db() @pytest.fixture(scope="session") -def database_name(test_id: TestIdFixture) -> Generator[DatabaseNameFixture, None, None]: - with DatabaseNameFixture.fixture(test_id) as fixture: +def database_creation( + session_test_id: TestIdFixture, +) -> Generator[DatabaseCreationFixture, None, None]: + """ + This is a session scoped fixture that provides a unique database for each worker in the test run. 
+ """ + + with DatabaseCreationFixture.fixture(session_test_id) as fixture: yield fixture @pytest.fixture(scope="function") -def database_name_func( - test_id_func: TestIdFixture, -) -> Generator[DatabaseNameFixture, None, None]: - with DatabaseNameFixture.fixture(test_id_func) as fixture: +def function_database_creation( + function_test_id: TestIdFixture, +) -> Generator[DatabaseCreationFixture, None, None]: + """ + This is a function scoped fixture that provides a unique database for each test function. + + This is resource intensive, so it should only be used when necessary. It is helpful + when testing database interactions that cannot be tested in a transaction, such as + instance initialization, schema migrations, etc. + """ + with DatabaseCreationFixture.fixture(function_test_id) as fixture: if not fixture.create_database: raise BasePalaceException( "Cannot provide a function scoped database when database creation is disabled." @@ -188,48 +252,71 @@ def database_name_func( class DatabaseFixture: - """The DatabaseFixture stores a reference to the database.""" - - # We store a reference to SessionManager.engine so that we can patch it in tests - # and still access the function from the fixture. - engine_func = SessionManager.engine + """ + The DatabaseFixture initializes the database schema and creates a connection to the database + that should be used in the tests. 
+ """ - def __init__(self, database_name: DatabaseNameFixture) -> None: + def __init__(self, database_name: DatabaseCreationFixture) -> None: self.database_name = database_name - self.engine = self.engine_func(url=self.database_name.worker_url) + self.engine = self.engine_factory() self.connection = self.engine.connect() + @staticmethod + def create_engine(url: str) -> Engine: + return create_engine( + url, + json_serializer=json_serializer, + ) + + def engine_factory(self) -> Engine: + return self.create_engine(self.database_name.url) + def drop_existing_schema(self) -> None: metadata_obj = MetaData() metadata_obj.reflect(bind=self.engine) metadata_obj.drop_all(self.engine) metadata_obj.clear() - def initialize_database(self) -> None: + def _initialize_database(self) -> None: SessionManager.initialize_schema(self.connection) with Session(self.connection) as session: # Initialize the database with default data SessionManager.initialize_data(session) @staticmethod - def load_model_classes(): - # Load all the core model classes so that they are registered with the ORM. + def _load_model_classes(): + """ + Make sure that all the model classes are loaded, so that they are registered with the + ORM when we are creating the schema. + """ import palace.manager.sqlalchemy.model importlib.reload(palace.manager.sqlalchemy.model) - def close(self): + def _close(self): # Destroy the database connection and engine. self.connection.close() self.engine.dispose() + @contextmanager + def patch_engine(self) -> Generator[None, None, None]: + """ + This method patches the SessionManager to use the engine provided by this fixture. + This is useful when the tests need to access the engine directly. 
+ """ + with patch.object(SessionManager, "engine", return_value=self.engine): + yield + @classmethod @contextmanager - def fixture(cls, database_name: DatabaseNameFixture) -> Generator[Self, None, None]: + def fixture( + cls, database_name: DatabaseCreationFixture + ) -> Generator[Self, None, None]: db_fixture = cls(database_name) db_fixture.drop_existing_schema() - db_fixture.load_model_classes() - db_fixture.initialize_database() + db_fixture._load_model_classes() + db_fixture._initialize_database() try: # Patch the SessionManager to make sure tests are not trying to access the engine directly. with patch.object( @@ -241,44 +328,40 @@ def fixture(cls, database_name: DatabaseNameFixture) -> Generator[Self, None, No ): yield db_fixture finally: - db_fixture.close() + db_fixture._close() @pytest.fixture(scope="session") def database( - database_name: DatabaseNameFixture, + database_creation: DatabaseCreationFixture, ) -> Generator[DatabaseFixture, None, None]: - with DatabaseFixture.fixture(database_name) as db: + """ + This is a session scoped fixture that provides a unique database engine and connection + for each worker in the test run. + """ + with DatabaseFixture.fixture(database_creation) as db: yield db @pytest.fixture(scope="function") -def database_func( - database_name_func: DatabaseNameFixture, +def function_database( + function_database_creation: DatabaseCreationFixture, ) -> Generator[DatabaseFixture, None, None]: - with DatabaseFixture.fixture(database_name_func) as db: + """ + This is a function scoped fixture that provides a unique database engine and connection + for each test. This is resource intensive, so it should only be used when necessary. + """ + with DatabaseFixture.fixture(function_database_creation) as db: yield db class DatabaseTransactionFixture: """A fixture representing a single transaction. 
The transaction is automatically rolled back.""" - _database: DatabaseFixture - _default_library: Library | None - _default_collection: Collection | None - _session: Session - _transaction: Transaction - _counter: int - _isbns: list[str] - - def __init__( - self, database: DatabaseFixture, session: Session, transaction: Transaction - ): + def __init__(self, database: DatabaseFixture): self._database = database - self._session = session - self._transaction = transaction - self._default_library = None - self._default_collection = None + self._default_library: Library | None = None + self._default_collection: Collection | None = None self._counter = 2000 self._isbns = [ "9780674368279", @@ -286,6 +369,8 @@ def __init__( "9781936460236", "9780316075978", ] + self._session = SessionManager.session_from_connection(database.connection) + self._transaction = database.connection.begin_nested() def _make_default_library(self) -> Library: """Ensure that the default library exists in the given database.""" @@ -299,15 +384,16 @@ def _make_default_library(self) -> Library: collection.libraries.append(library) return library - @staticmethod - def create(database: DatabaseFixture) -> DatabaseTransactionFixture: - # Create a new connection to the database. - session = SessionManager.session_from_connection(database.connection) - - transaction = database.connection.begin_nested() - return DatabaseTransactionFixture(database, session, transaction) + @classmethod + @contextmanager + def fixture(cls, database: DatabaseFixture) -> Generator[Self, None, None]: + db = cls(database) + try: + yield db + finally: + db._close() - def close(self): + def _close(self): # Close the session. 
self._session.close() @@ -998,9 +1084,8 @@ def credential(self, data_source_name=DataSource.GUTENBERG, type=None, patron=No def db( database: DatabaseFixture, ) -> Generator[DatabaseTransactionFixture, None, None]: - tr = DatabaseTransactionFixture.create(database) - yield tr - tr.close() + with DatabaseTransactionFixture.fixture(database) as db: + yield db class TemporaryDirectoryConfigurationFixture: diff --git a/tests/fixtures/search.py b/tests/fixtures/search.py index 4936951e8e..5626fa7fe6 100644 --- a/tests/fixtures/search.py +++ b/tests/fixtures/search.py @@ -96,12 +96,12 @@ def fixture( def external_search_fixture( db: DatabaseTransactionFixture, services_fixture: ServicesFixture, - test_id_func: TestIdFixture, + function_test_id: TestIdFixture, ) -> Generator[ExternalSearchFixture, None, None]: """Ask for an external search system.""" """Note: You probably want EndToEndSearchFixture instead.""" with ExternalSearchFixture.fixture( - db, services_fixture.services, test_id_func + db, services_fixture.services, function_test_id ) as fixture: yield fixture diff --git a/tests/manager/api/controller/test_scopedsession.py b/tests/manager/api/controller/test_scopedsession.py index 6aefc8257e..b07e68493a 100644 --- a/tests/manager/api/controller/test_scopedsession.py +++ b/tests/manager/api/controller/test_scopedsession.py @@ -1,6 +1,6 @@ from collections.abc import Generator from contextlib import contextmanager -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import flask import pytest @@ -11,7 +11,6 @@ from palace.manager.sqlalchemy.flask_sqlalchemy_session import current_session from palace.manager.sqlalchemy.model.datasource import DataSource from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.session import SessionManager from tests.fixtures.database import DatabaseFixture from tests.fixtures.services import ServicesFixture from tests.mocks.circulation import MockCirculationManager 
@@ -25,13 +24,13 @@ def __init__( self.session = session self.services = services self.app = app - with (patch.object(SessionManager, "engine", return_value=db_fixture.engine),): + with db_fixture.patch_engine(): initialize_database() self.app.manager = MockCirculationManager(app._db, services.services) self.mock_library = MagicMock() self.mock_library.has_root_lanes = False - def cleanup(self) -> None: + def _cleanup(self) -> None: delattr(self.app, "manager") delattr(self.app, "_db") @@ -43,14 +42,20 @@ def fixture( with Session(db_fixture.connection) as session: fixture = cls(db_fixture, services_fixture, session) yield fixture - fixture.cleanup() + fixture._cleanup() + + @contextmanager + def request_context(self, path: str) -> Generator[None, None, None]: + with self.app.test_request_context(path) as ctx: + ctx.request.library = self.mock_library # type: ignore[attr-defined] + yield @pytest.fixture def scoped_session_fixture( - database_func: DatabaseFixture, services_fixture: ServicesFixture + function_database: DatabaseFixture, services_fixture: ServicesFixture ) -> Generator[ScopedSessionFixture, None, None]: - with ScopedSessionFixture.fixture(database_func, services_fixture) as fixture: + with ScopedSessionFixture.fixture(function_database, services_fixture) as fixture: yield fixture @@ -68,7 +73,7 @@ def test_scoped_session( scoped_session_fixture: ScopedSessionFixture, ): # Start a simulated request to the Flask app server. - with scoped_session_fixture.app.test_request_context("/"): + with scoped_session_fixture.request_context("/"): # Each request is given its own database session distinct # from the one used by most unit tests and the one created # outside of this context. @@ -109,13 +114,12 @@ def test_scoped_session( # When the index controller runs in the request context, # it doesn't store anything that's associated with the # scoped session. 
- flask.request.library = scoped_session_fixture.mock_library response = app.manager.index_controller() assert 302 == response.status_code # Once we exit the context of the Flask request, the # transaction is committed and the Identifier is written to the - # database. That is why we run this test with the database_func + # database. That is why we run this test with the function_database # fixture, which gives us a function scoped database to work with. # This database is removed after the test completes, so we don't # have to worry about cleaning up the database after ourselves. @@ -123,7 +127,7 @@ def test_scoped_session( assert "1024" == identifier.identifier # Now create a different simulated Flask request - with app.test_request_context("/"): + with scoped_session_fixture.request_context("/"): session2 = current_session() assert session2 != scoped_session_fixture.session assert session2 != app.manager._db @@ -131,7 +135,7 @@ def test_scoped_session( # The controller still works in the new request context - # nothing it needs is associated with the previous scoped # session. 
- flask.request.library = scoped_session_fixture.mock_library + flask.request.library = scoped_session_fixture.mock_library # type: ignore[attr-defined] response = app.manager.index_controller() assert 302 == response.status_code diff --git a/tests/manager/api/test_scripts.py b/tests/manager/api/test_scripts.py index 743018dfa5..317aed56cf 100644 --- a/tests/manager/api/test_scripts.py +++ b/tests/manager/api/test_scripts.py @@ -541,18 +541,18 @@ class TestInstanceInitializationScript: def test_run_locks_database(self, db: DatabaseTransactionFixture): # The script locks the database with a PostgreSQL advisory lock - with patch("palace.manager.scripts.SessionManager") as session_manager: - with patch("palace.manager.scripts.pg_advisory_lock") as advisory_lock: - script = InstanceInitializationScript() - script.initialize = MagicMock() - script.run() - - advisory_lock.assert_called_once_with( - session_manager.engine().begin().__enter__(), - LOCK_ID_DB_INIT, - ) - advisory_lock().__enter__.assert_called_once() - advisory_lock().__exit__.assert_called_once() + with patch("palace.manager.scripts.pg_advisory_lock") as advisory_lock: + mock_engine_factory = MagicMock() + script = InstanceInitializationScript(engine_factory=mock_engine_factory) + script.initialize = MagicMock() + script.run() + + advisory_lock.assert_called_once_with( + mock_engine_factory().begin().__enter__(), + LOCK_ID_DB_INIT, + ) + advisory_lock().__enter__.assert_called_once() + advisory_lock().__exit__.assert_called_once() def test_initialize(self, db: DatabaseTransactionFixture): # Test that the script inspects the database and initializes or migrates the database diff --git a/tests/migration/conftest.py b/tests/migration/conftest.py index b3a37652d1..392c5fee4d 100644 --- a/tests/migration/conftest.py +++ b/tests/migration/conftest.py @@ -35,11 +35,11 @@ def alembic_config(alembic_config_path: Path) -> Config: @pytest.fixture -def alembic_engine(database_func: DatabaseFixture) -> Engine: +def 
alembic_engine(function_database: DatabaseFixture) -> Engine: """ Override this fixture to provide pytest-alembic powered tests with a database handle. """ - return database_func.engine + return function_database.engine @pytest.fixture diff --git a/tests/migration/test_instance_init_script.py b/tests/migration/test_instance_init_script.py index fb5546373b..9e67360ed4 100644 --- a/tests/migration/test_instance_init_script.py +++ b/tests/migration/test_instance_init_script.py @@ -5,16 +5,15 @@ from io import StringIO from multiprocessing import Process from pathlib import Path -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock import pytest from pytest_alembic import MigrationContext from sqlalchemy import inspect +from sqlalchemy.engine import Engine from typing_extensions import Self -from palace.manager.core.config import Configuration from palace.manager.scripts import InstanceInitializationScript -from palace.manager.sqlalchemy.session import SessionManager from tests.fixtures.database import DatabaseFixture from tests.fixtures.services import ServicesFixture, mock_services_container @@ -22,54 +21,58 @@ class InstanceInitScriptFixture: def __init__( self, - database_func: DatabaseFixture, + function_database: DatabaseFixture, services_fixture: ServicesFixture, alembic_config_path: Path, ): - self.database = database_func + self.database = function_database self.services = services_fixture self.alembic_config_path = alembic_config_path def script(self) -> InstanceInitializationScript: - return InstanceInitializationScript(config_file=self.alembic_config_path) + return InstanceInitializationScript( + config_file=self.alembic_config_path, + engine_factory=self.database.engine_factory, + ) @classmethod @contextmanager def fixture( cls, - database_func: DatabaseFixture, + function_database: DatabaseFixture, services_fixture: ServicesFixture, alembic_config_path: Path, ) -> Generator[Self, None, None]: - fixture = 
cls(database_func, services_fixture, alembic_config_path) - with patch.object(SessionManager, "engine", fixture.database.engine_func): - yield fixture + fixture = cls(function_database, services_fixture, alembic_config_path) + yield fixture @pytest.fixture def instance_init_script_fixture( - database_func: DatabaseFixture, + function_database: DatabaseFixture, services_fixture: ServicesFixture, alembic_config_path: Path, ) -> Generator[InstanceInitScriptFixture, None, None]: with InstanceInitScriptFixture.fixture( - database_func, services_fixture, alembic_config_path + function_database, services_fixture, alembic_config_path ) as fixture: yield fixture -def _run_script(config_path: Path, worker_url: str) -> None: +def _run_script(config_path: Path, db_url: str) -> None: try: # Capturing the log output stream = StringIO() logging.basicConfig(stream=stream, level=logging.INFO, force=True) + def engine_factory() -> Engine: + return DatabaseFixture.create_engine(db_url) + mock_services = MagicMock() - with ( - mock_services_container(mock_services), - patch.object(Configuration, "database_url", return_value=worker_url), - ): - script = InstanceInitializationScript(config_file=config_path) + with (mock_services_container(mock_services),): + script = InstanceInitializationScript( + config_file=config_path, engine_factory=engine_factory + ) script.run() # Set our exit code to the number of upgrades we ran @@ -88,7 +91,7 @@ def test_locking( ) -> None: # Migrate to the initial revision alembic_runner.migrate_down_to("base") - worker_url = instance_init_script_fixture.database.database_name.worker_url + db_url = instance_init_script_fixture.database.database_name.url # Spawn three processes, that will all try to migrate to head # at the same time. One of them should do the migration, and @@ -96,7 +99,7 @@ def test_locking( # has already been done. 
process_kwargs = { "config_path": alembic_config_path, - "worker_url": worker_url, + "db_url": db_url, } p1 = Process(target=_run_script, kwargs=process_kwargs) p2 = Process(target=_run_script, kwargs=process_kwargs)