Skip to content

Commit

Permalink
Fix parallel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen committed May 1, 2024
1 parent f19dbb7 commit 993aa79
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 29 deletions.
9 changes: 3 additions & 6 deletions src/palace/manager/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import os
import sys
import time
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from pathlib import Path
from typing import Any

from alembic import command, config
from alembic.util import CommandError
from sqlalchemy import inspect, select
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine import Connection
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -489,7 +489,6 @@ class InstanceInitializationScript:
def __init__(
self,
config_file: Path | None = None,
engine_factory: Callable[[], Engine] | None = None,
) -> None:
self._log: logging.Logger | None = None
self._container = container_instance()
Expand All @@ -498,8 +497,6 @@ def __init__(
self._container.init_resources()
self._config_file = config_file

self._engine_factory = engine_factory or SessionManager.engine

@property
def log(self) -> logging.Logger:
if self._log is None:
Expand Down Expand Up @@ -574,7 +571,7 @@ def run(self) -> None:
instance of the script is running at a time. This prevents multiple
instances from trying to initialize the database at the same time.
"""
engine = self._engine_factory()
engine = SessionManager.engine()
with engine.begin() as connection:
with pg_advisory_lock(connection, LOCK_ID_DB_INIT):
self.initialize(connection)
Expand Down
5 changes: 1 addition & 4 deletions tests/fixtures/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from Crypto.PublicKey.RSA import import_key
from pydantic import PostgresDsn
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy.engine import Connection, Engine, Transaction, make_url
from sqlalchemy.engine import Connection, Transaction, make_url
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import Self

Expand Down Expand Up @@ -211,9 +211,6 @@ def initialize_database(self) -> None:
# Initialize the database with default data
SessionManager.initialize_data(session)

def engine_factory(self) -> Engine:
return self.engine

@staticmethod
def load_model_classes():
# Load all the core model classes so that they are registered with the ORM.
Expand Down
93 changes: 74 additions & 19 deletions tests/migration/test_instance_init_script.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,75 @@
import logging
import sys
from collections.abc import Generator
from contextlib import contextmanager
from io import StringIO
from multiprocessing import Process
from pathlib import Path
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import pytest
from pytest_alembic import MigrationContext
from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from typing_extensions import Self

from palace.manager.core.config import Configuration
from palace.manager.scripts import InstanceInitializationScript
from tests.fixtures.database import DatabaseFixture
from tests.fixtures.services import mock_services_container
from palace.manager.sqlalchemy.session import SessionManager
from tests.fixtures.database import DatabaseFixture, DatabaseNameFixture
from tests.fixtures.services import ServicesFixture, mock_services_container


class InstanceInitScriptFixture:
def __init__(
self,
database_func: DatabaseFixture,
services_fixture: ServicesFixture,
alembic_config_path: Path,
):
self.database = database_func
self.services = services_fixture
self.alembic_config_path = alembic_config_path

def script(self) -> InstanceInitializationScript:
return InstanceInitializationScript(config_file=self.alembic_config_path)

@classmethod
@contextmanager
def fixture(
cls,
database_func: DatabaseFixture,
services_fixture: ServicesFixture,
alembic_config_path: Path,
) -> Generator[Self, None, None]:
fixture = cls(database_func, services_fixture, alembic_config_path)
with patch.object(SessionManager, "engine", fixture.database.engine_func):
yield fixture


@pytest.fixture
def instance_init_script_fixture(
database_func: DatabaseFixture,
services_fixture: ServicesFixture,
alembic_config_path: Path,
) -> Generator[InstanceInitScriptFixture, None, None]:
with InstanceInitScriptFixture.fixture(
database_func, services_fixture, alembic_config_path
) as fixture:
yield fixture


def _run_script() -> None:
def _run_script(config_path: Path, worker_url: str) -> None:
try:
# Capturing the log output
stream = StringIO()
logging.basicConfig(stream=stream, level=logging.INFO, force=True)

mock_services = MagicMock()
with mock_services_container(mock_services):
script = InstanceInitializationScript()
with (
mock_services_container(mock_services),
patch.object(Configuration, "database_url", return_value=worker_url),
):
script = InstanceInitializationScript(config_file=config_path)
script.run()

# Set our exit code to the number of upgrades we ran
Expand All @@ -34,17 +81,25 @@ def _run_script() -> None:
sys.exit(-1)


def test_locking(alembic_runner: MigrationContext, alembic_engine: Engine) -> None:
def test_locking(
alembic_runner: MigrationContext,
alembic_config_path: Path,
database_name_func: DatabaseNameFixture,
) -> None:
# Migrate to the initial revision
alembic_runner.migrate_down_to("base")

# Spawn three processes, that will all try to migrate to head
# at the same time. One of them should do the migration, and
# the other two should wait, then do no migration since it
# has already been done.
p1 = Process(target=_run_script)
p2 = Process(target=_run_script)
p3 = Process(target=_run_script)
process_kwargs = {
"config_path": alembic_config_path,
"worker_url": database_name_func.worker_url,
}
p1 = Process(target=_run_script, kwargs=process_kwargs)
p2 = Process(target=_run_script, kwargs=process_kwargs)
p3 = Process(target=_run_script, kwargs=process_kwargs)

p1.start()
p2.start()
Expand All @@ -67,14 +122,17 @@ def test_locking(alembic_runner: MigrationContext, alembic_engine: Engine) -> No
assert exit_codes[2] == 0


def test_initialize(database_func: DatabaseFixture) -> None:
def test_initialize(instance_init_script_fixture: InstanceInitScriptFixture) -> None:
# Drop any existing schema
instance_init_script_fixture.database.drop_existing_schema()

# Run the script and make sure we create the alembic_version table
engine = database_func.engine
engine = instance_init_script_fixture.database.engine
inspector = inspect(engine)
assert "alembic_version" not in inspector.get_table_names()
assert len(inspector.get_table_names()) == 0

script = InstanceInitializationScript()
script = instance_init_script_fixture.script()
script.initialize_database = Mock(wraps=script.initialize_database)
script.migrate_database = Mock(wraps=script.migrate_database)
script.run()
Expand All @@ -96,8 +154,7 @@ def test_initialize(database_func: DatabaseFixture) -> None:

def test_migrate(
alembic_runner: MigrationContext,
database_func: DatabaseFixture,
alembic_config_path: Path,
instance_init_script_fixture: InstanceInitScriptFixture,
) -> None:
# Run the script and make sure we create the alembic_version table
# Migrate to the initial revision
Expand All @@ -107,9 +164,7 @@ def test_migrate(
assert alembic_runner.current == "base"
assert alembic_runner.current != alembic_runner.heads[0]

script = InstanceInitializationScript(
config_file=alembic_config_path, engine_factory=database_func.engine_factory
)
script = instance_init_script_fixture.script()
script.initialize_database = Mock(wraps=script.initialize_database)
script.migrate_database = Mock(wraps=script.migrate_database)
script.run()
Expand Down

0 comments on commit 993aa79

Please sign in to comment.