Skip to content

Commit

Permalink
feat: update sqlalchemy to use standard api instead of async one, sta…
Browse files Browse the repository at this point in the history
…ndard has better documentation, more community support, and cleaner code
  • Loading branch information
lchen-2101 committed Dec 23, 2024
1 parent 4a8db27 commit 267bb53
Show file tree
Hide file tree
Showing 17 changed files with 240 additions and 290 deletions.
40 changes: 2 additions & 38 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ regtech-regex = {git = "https://github.com/cfpb/regtech-regex.git"}

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
aiosqlite = "^0.20.0"
pytest-cov = "^6.0.0"
pytest-mock = "^3.11.1"
pytest-env = "^1.1.4"
Expand Down
55 changes: 55 additions & 0 deletions src/regtech_user_fi_management/config.local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
from pathlib import Path
from typing import Set
from urllib import parse

from pydantic import field_validator, PostgresDsn, ValidationInfo
from pydantic_settings import BaseSettings, SettingsConfigDict

from regtech_api_commons.oauth2.config import KeycloakSettings
from regtech_regex.regex_config import RegexConfigs


JWT_OPTS_PREFIX = "jwt_opts_"

dir_path = os.path.dirname(os.path.realpath(__file__))
env_files_to_load: list[Path | str] = [f"{dir_path}/.env"]
if os.getenv("ENV", "LOCAL") == "LOCAL":
env_files_to_load.append(f"{dir_path}/.env.local")


class Settings(BaseSettings):
inst_db_schema: str = "public"
inst_db_name: str
inst_db_user: str
inst_db_pwd: str
inst_db_host: str
inst_db_scheme: str = "postgresql+psycopg2"
# inst_db_scheme: str = "postgresql+asyncpg"
inst_conn: str | None = None
admin_scopes: Set[str] = set(["query-groups", "manage-users"])
db_logging: bool = True

def __init__(self, **data):
super().__init__(**data)

@field_validator("inst_conn", mode="before")
@classmethod
def build_postgres_dsn(cls, field_value, info: ValidationInfo) -> str:
postgres_dsn = PostgresDsn.build(
scheme=info.data.get("inst_db_scheme"),
username=info.data.get("inst_db_user"),
password=parse.quote(str(info.data.get("inst_db_pwd")), safe=""),
host=info.data.get("inst_db_host"),
path=info.data.get("inst_db_name"),
)
return postgres_dsn.unicode_string()

model_config = SettingsConfigDict(env_file=env_files_to_load, extra="allow")


settings = Settings()

kc_settings = KeycloakSettings(_env_file=env_files_to_load)

regex_configs = RegexConfigs.instance()
2 changes: 1 addition & 1 deletion src/regtech_user_fi_management/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Settings(BaseSettings):
inst_db_user: str
inst_db_pwd: str
inst_db_host: str
inst_db_scheme: str = "postgresql+asyncpg"
inst_db_scheme: str = "postgresql+psycopg2"
inst_conn: str | None = None
admin_scopes: Set[str] = set(["query-groups", "manage-users"])
db_logging: bool = True
Expand Down
8 changes: 4 additions & 4 deletions src/regtech_user_fi_management/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
from regtech_api_commons.api.dependencies import get_email_domain


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if not request.user.is_authenticated:
raise RegTechHttpException(
status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="unauthenticated user"
)
if await email_domain_denied(session, get_email_domain(request.user.email)):
if email_domain_denied(session, get_email_domain(request.user.email)):
raise RegTechHttpException(
status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="email domain denied"
)


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_domain_allowed(session, email)
def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not repo.is_domain_allowed(session, email)
16 changes: 6 additions & 10 deletions src/regtech_user_fi_management/entities/engine/engine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
async_scoped_session,
)
from asyncio import current_task
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from regtech_user_fi_management.config import settings

engine = create_async_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options(
engine = create_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options(
schema_translate_map={None: settings.inst_db_schema}
)
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)
SessionLocal = scoped_session(sessionmaker(engine, expire_on_commit=False))


async def get_session():
def get_session():
session = SessionLocal()
try:
yield session
finally:
await session.close()
session.close()
11 changes: 3 additions & 8 deletions src/regtech_user_fi_management/entities/listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,9 @@ def _insert_history(
return _insert_history


async def setup_dao_listeners():
async with engine.begin() as connection:
fi_history, mapping_history = await connection.run_sync(
lambda conn: (
Table("financial_institutions_history", Base.metadata, autoload_with=conn),
Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn),
)
)
def setup_dao_listeners():
fi_history = Table("financial_institutions_history", Base.metadata, autoload_with=engine)
mapping_history = Table("fi_to_type_mapping_history", Base.metadata, autoload_with=engine)

insert_fi_history = _setup_fi_history(fi_history, mapping_history)

Expand Down
3 changes: 1 addition & 2 deletions src/regtech_user_fi_management/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from typing import List
from sqlalchemy import ForeignKey, func, String, inspect
from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs


class Base(AsyncAttrs, DeclarativeBase):
class Base(DeclarativeBase):
pass


Expand Down
122 changes: 49 additions & 73 deletions src/regtech_user_fi_management/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List, Sequence, Set

from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from regtech_api_commons.models.auth import AuthenticatedUser

from .repo_utils import get_associated_sbl_types, query_type
from .repo_utils import get_associated_sbl_types

from regtech_user_fi_management.entities.models.dao import (
FinancialInstitutionDao,
Expand All @@ -25,76 +25,61 @@
)


async def get_institutions(
session: AsyncSession,
def get_institutions(
session: Session,
leis: List[str] | None = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> Sequence[FinancialInstitutionDao]:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
.options(joinedload(FinancialInstitutionDao.domains))
.limit(count)
.offset(page * count)
)
if leis is not None:
stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
stmt = stmt.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d)
res = await session.scalars(stmt)
return res.unique().all()


async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
.options(joinedload(FinancialInstitutionDao.domains))
.filter(FinancialInstitutionDao.lei == lei)
)
return await session.scalar(stmt)
query = session.query(FinancialInstitutionDao)
if leis is not None:
query = query.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
query = query.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d)
return query.limit(count).offset(page * count).all()


def get_institution(session: Session, lei: str) -> FinancialInstitutionDao | None:
return session.get(FinancialInstitutionDao, lei)

async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)

def get_sbl_types(session: Session) -> Sequence[SBLInstitutionTypeDao]:
return session.query(SBLInstitutionTypeDao).all()

async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)

def get_hmda_types(session: Session) -> Sequence[HMDAInstitutionTypeDao]:
return session.query(HMDAInstitutionTypeDao).all()

async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]:
return await query_type(session, AddressStateDao)

def get_address_states(session: Session) -> Sequence[AddressStateDao]:
return session.query(AddressStateDao).all()

async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)

def get_federal_regulators(session: Session) -> Sequence[FederalRegulatorDao]:
return session.query(FederalRegulatorDao).all()

async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser

def upsert_institution(
session: Session, fi: FinancialInstitutionDto, user: AuthenticatedUser
) -> FinancialInstitutionDao:
async with session.begin():
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
fi_data.pop("version", None)
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association
if "sbl_institution_types" in fi_data:
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
await session.flush()
await session.refresh(db_fi)
return db_fi
db_fi = session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
session.commit()
return db_fi


async def update_sbl_types(
session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
def update_sbl_types(
session: Session, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
) -> FinancialInstitutionDao | None:
if fi := await get_institution(session, lei):
if fi := get_institution(session, lei):
new_types = set(get_associated_sbl_types(lei, user.id, sbl_types))
old_types = set(fi.sbl_institution_types)
add_types = new_types.difference(old_types)
Expand All @@ -104,34 +89,25 @@ async def update_sbl_types(
fi.sbl_institution_types.extend(add_types)
for type in fi.sbl_institution_types:
type.version = fi.version
await session.commit()
"""
load the async relational attributes so dto can be properly serialized
"""
for type in fi.sbl_institution_types:
await type.awaitable_attrs.sbl_type
session.commit()
return fi


async def add_domains(
session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate]
def add_domains(
session: Session, lei: str, domains: List[FinancialInsitutionDomainCreate]
) -> Set[FinancialInstitutionDomainDao]:
async with session.begin():
daos = set(
map(
lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei),
domains,
)
daos = set(
map(
lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei),
domains,
)
session.add_all(daos)
await session.commit()
return daos
)
session.add_all(daos)
session.commit()
return daos


async def is_domain_allowed(session: AsyncSession, domain: str) -> bool:
def is_domain_allowed(session: Session, domain: str) -> bool:
if domain:
async with session:
stmt = select(func.count()).filter(DeniedDomainDao.domain == domain)
res = await session.scalar(stmt)
return res == 0
return session.get(DeniedDomainDao, domain) is None
return False
Loading

0 comments on commit 267bb53

Please sign in to comment.