Skip to content

Commit

Permalink
refactor: use Depends for settings and session_local (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
fspoettel authored Jan 1, 2024
1 parent 557de5a commit 3559aa5
Show file tree
Hide file tree
Showing 20 changed files with 205 additions and 120 deletions.
10 changes: 3 additions & 7 deletions app/shared/celery.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from celery import Celery

from app.shared.settings import settings


def get_celery_binding() -> Celery:
celery = Celery(
broker_url=settings.BROKER_URL,
def get_celery_binding(broker_url: str) -> Celery:
return Celery(
broker_url=broker_url,
broker_connection_retry=False,
broker_connection_retry_on_startup=False,
)

return celery
4 changes: 3 additions & 1 deletion app/shared/db/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from sqlalchemy import engine_from_config, pool

from app.shared.db.models import Base
from app.shared.settings import settings
from app.shared.settings import Settings

settings = Settings() # type: ignore

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
33 changes: 14 additions & 19 deletions app/shared/db/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
from typing import Any, Generator
from typing import Any

from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker

from app.shared.settings import settings

engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
def make_engine(database_url: str):
engine = create_engine(database_url, connect_args={"check_same_thread": False})

@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()

@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine


SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_session() -> Generator[Session, None, None]:
session: Session = SessionLocal()
try:
yield session
finally:
session.close()
def make_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
7 changes: 7 additions & 0 deletions app/shared/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import logging

logging.basicConfig()

logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)
8 changes: 0 additions & 8 deletions app/shared/settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import sys

from pydantic_settings import BaseSettings


Expand All @@ -13,9 +11,3 @@ class Settings(BaseSettings):
TASK_HARD_TIME_LIMIT: int = 4 * 60 * 60

ENABLE_SHARING: bool = False


if "pytest" in sys.modules:
settings = Settings(_env_file=".env.test") # type: ignore
else:
settings = Settings() # type: ignore
50 changes: 28 additions & 22 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,57 @@
from sqlalchemy_utils import create_database, database_exists, drop_database

import app.shared.db.models as models
from app.shared.db.base import SessionLocal, engine
from app.shared.settings import settings
from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import Settings
from app.web.injections.db import get_session
from app.web.injections.settings import get_settings
from app.web.main import app_factory


def pytest_configure() -> None:
if not database_exists(engine.url):
create_database(engine.url)


def pytest_unconfigure() -> None:
if database_exists(engine.url):
drop_database(engine.url)
@pytest.fixture()
def settings():
return Settings(_env_file=".env.test") # type: ignore


@pytest.fixture()
def auth_headers() -> dict[str, str]:
def auth_headers(settings) -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"}


@pytest.fixture()
def test_db():
def test_db(settings):
engine = make_engine(settings.DATABASE_URI)

if not database_exists(engine.url):
create_database(engine.url)

models.Base.metadata.create_all(engine)

connection = engine.connect()
yield connection
connection.close()

models.Base.metadata.drop_all(bind=engine)
drop_database(engine.url)


@pytest.fixture()
def db_session(test_db):
with SessionLocal(bind=test_db) as session:
session_local = make_session_local(test_db)
with session_local() as session:
yield session


@pytest.fixture()
def client(db_session):
app = app_factory(lambda: db_session)
def app(db_session, settings):
app = app_factory()
app.dependency_overrides[get_settings] = lambda: settings
app.dependency_overrides[get_session] = lambda: db_session
return app


@pytest.fixture()
def client(app):
client = TestClient(app)
return client

Expand All @@ -66,10 +79,3 @@ def mock_artifact(db_session, mock_job):
db_session.add(artifact)
db_session.commit()
return artifact


@pytest.fixture()
def sharing_enabled():
settings.ENABLE_SHARING = True
yield
settings.ENABLE_SHARING = False
12 changes: 6 additions & 6 deletions app/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi.testclient import TestClient

import app.shared.db.models as models
from app.web.main import app_factory
from app.shared.settings import Settings
from app.web.injections.settings import get_settings


# POST /api/v1/jobs
Expand Down Expand Up @@ -69,9 +68,10 @@ def test_get_job_sharing_disabled(client, mock_job):
assert res.status_code == 401


def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
# HACK: delay construction until settings are patched.
client = TestClient(app_factory(lambda: db_session))
def test_get_job_sharing_enabled(client, app, mock_job):
app.dependency_overrides[get_settings] = lambda: Settings(
_env_file=".env.test", ENABLE_SHARING=True # type: ignore
)

res = client.get(
f"/api/v1/jobs/{mock_job.id}",
Expand Down
3 changes: 1 addition & 2 deletions app/web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from app.shared.db.base import get_session
from app.web.main import app_factory

app = app_factory(get_session)
app = app_factory
Empty file added app/web/injections/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions app/web/injections/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from functools import lru_cache
from typing import Generator

from fastapi import Depends
from sqlalchemy.orm import Session

from app.shared.db.base import make_engine, make_session_local
from app.shared.settings import Settings
from app.web.injections.settings import get_settings


@lru_cache
def session_local(database_url: str):
engine = make_engine(database_url)
return make_session_local(engine)


def get_session_local(settings: Settings = Depends(get_settings)):
return session_local(settings.DATABASE_URI)


def get_session(
session_local=Depends(get_session_local),
) -> Generator[Session, None, None]:
with session_local() as session:
yield session
39 changes: 39 additions & 0 deletions app/web/injections/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from hmac import compare_digest
from typing import Annotated

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.shared.settings import Settings
from app.web.injections.settings import get_settings


def api_key_auth(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
validate_credentials(credentials, settings.API_SECRET)


def sharing_auth(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
if settings.ENABLE_SHARING:
pass
else:
validate_credentials(credentials, settings.API_SECRET)


def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str):
# use compare_digest to counter timing attacks.
if (
not credentials
or not secret
or not compare_digest(secret, credentials.credentials)
):
raise HTTPException(status_code=401)
8 changes: 8 additions & 0 deletions app/web/injections/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from functools import lru_cache

from app.shared.settings import Settings


@lru_cache
def get_settings():
return Settings() # type: ignore
16 changes: 16 additions & 0 deletions app/web/injections/task_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from functools import lru_cache

from fastapi import Depends

from app.shared.settings import Settings
from app.web.injections.settings import get_settings
from app.web.task_queue import TaskQueue


@lru_cache
def task_queue(broker_url: str):
return TaskQueue(broker_url)


def get_task_queue(settings: Settings = Depends(get_settings)):
return task_queue(settings.BROKER_URL)
30 changes: 14 additions & 16 deletions app/web/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Callable, Generator
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
Expand All @@ -7,18 +7,15 @@

import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.web.injections.db import get_session
from app.web.injections.security import api_key_auth, sharing_auth
from app.web.injections.task_queue import get_task_queue
from app.web.task_queue import TaskQueue

DatabaseSession = Annotated[Session, Depends(get_session)]

def app_factory(
session_getter: Callable[[], Generator[Session, None, None]]
) -> FastAPI:
DatabaseSession = Annotated[Session, Depends(session_getter)]

task_queue = TaskQueue()

def app_factory():
app = FastAPI(
description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
Expand All @@ -28,13 +25,13 @@ def app_factory(

api_router = APIRouter(prefix="/api/v1")

@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
@api_router.get("/", status_code=204)
def api_root():
return None

@api_router.get(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Job],
summary="Get metadata for all jobs",
)
Expand All @@ -52,7 +49,7 @@ def get_jobs(

@api_router.get(
"/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
dependencies=[Depends(sharing_auth)],
response_model=dtos.Job,
summary="Get metadata for one job",
)
Expand All @@ -72,7 +69,7 @@ def get_job(

@api_router.get(
"/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
dependencies=[Depends(api_key_auth)],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
Expand All @@ -93,7 +90,7 @@ def get_artifacts_for_job(

@api_router.delete(
"/jobs/{id}",
dependencies=[Depends(authenticate_api_key)],
dependencies=[Depends(sharing_auth)],
status_code=204,
summary="Delete a job with all artifacts",
)
Expand Down Expand Up @@ -130,14 +127,15 @@ class PostJobPayload(BaseModel):

@api_router.post(
"/jobs",
dependencies=[Depends(authenticate_api_key)],
dependencies=[Depends(api_key_auth)],
response_model=dtos.Job,
status_code=201,
summary="Enqueue a new job",
)
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Expand Down
Loading

0 comments on commit 3559aa5

Please sign in to comment.