From 267bb53f9b5e00afaa1ba36c40fc4a90745442d7 Mon Sep 17 00:00:00 2001 From: lchen <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:24:33 -0800 Subject: [PATCH] feat: update sqlalchemy to use standard api instead of async one, standard has better documentation, more community support, and cleaner code --- poetry.lock | 40 +---- pyproject.toml | 2 - .../config.local.py | 55 +++++++ src/regtech_user_fi_management/config.py | 2 +- .../dependencies.py | 8 +- .../entities/engine/engine.py | 16 +- .../entities/listeners.py | 11 +- .../entities/models/dao.py | 3 +- .../entities/repos/institutions_repo.py | 122 ++++++--------- .../entities/repos/repo_utils.py | 11 +- src/regtech_user_fi_management/main.py | 2 +- .../routers/institutions.py | 50 +++--- tests/api/conftest.py | 4 +- tests/app/test_config.py | 2 +- tests/app/test_dependencies.py | 8 +- tests/entities/conftest.py | 48 ++---- .../entities/repos/test_institutions_repo.py | 146 +++++++++--------- 17 files changed, 240 insertions(+), 290 deletions(-) create mode 100644 src/regtech_user_fi_management/config.local.py diff --git a/poetry.lock b/poetry.lock index 4d8372e..2ed74cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,22 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. - -[[package]] -name = "aiosqlite" -version = "0.20.0" -description = "asyncio bridge to the standard sqlite3 module" -optional = false -python-versions = ">=3.8" -files = [ - {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"}, - {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"}, -] - -[package.dependencies] -typing_extensions = ">=4.0" - -[package.extras] -dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] -docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" @@ -1178,24 +1160,6 @@ alembic = "*" pytest = ">=6.0" sqlalchemy = "*" -[[package]] -name = "pytest-asyncio" -version = "0.24.0" -description = "Pytest support for asyncio" -optional = false -python-versions = ">=3.8" -files = [ - {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, - {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, -] - -[package.dependencies] -pytest = ">=8.2,<9" - -[package.extras] -docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] -testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] - [[package]] name = "pytest-cov" version = "6.0.0" @@ -1665,4 +1629,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [metadata] lock-version = "2.0" python-versions = ">=3.12,<4" -content-hash = "9f1660cda4c66099a00302ae3c6400ce0d302f5cd8dc485b0111f5551823c252" +content-hash = "9d8f83e7084fd45ae1f3ff8c6bc62883821f3c10457c9c40d4d799e56eb2d8b9" diff --git a/pyproject.toml b/pyproject.toml index 4c55849..183d792 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,6 @@ regtech-regex = {git = "https://github.com/cfpb/regtech-regex.git"} [tool.poetry.group.dev.dependencies] pytest = "^8.3.3" -pytest-asyncio = "^0.24.0" -aiosqlite = "^0.20.0" pytest-cov = "^6.0.0" pytest-mock = "^3.11.1" pytest-env = "^1.1.4" diff --git a/src/regtech_user_fi_management/config.local.py b/src/regtech_user_fi_management/config.local.py new file mode 100644 index 0000000..66b21ea --- /dev/null +++ b/src/regtech_user_fi_management/config.local.py @@ -0,0 +1,55 @@ +import os +from pathlib import Path +from typing import Set +from urllib import parse + +from pydantic import field_validator, PostgresDsn, ValidationInfo +from pydantic_settings import BaseSettings, SettingsConfigDict + +from regtech_api_commons.oauth2.config import KeycloakSettings +from regtech_regex.regex_config import RegexConfigs + + +JWT_OPTS_PREFIX = "jwt_opts_" + +dir_path = os.path.dirname(os.path.realpath(__file__)) +env_files_to_load: list[Path | str] = [f"{dir_path}/.env"] +if os.getenv("ENV", "LOCAL") == "LOCAL": + env_files_to_load.append(f"{dir_path}/.env.local") + + +class Settings(BaseSettings): + inst_db_schema: str = "public" + inst_db_name: str + inst_db_user: str + inst_db_pwd: str + inst_db_host: str + inst_db_scheme: str = "postgresql+psycopg2" + # inst_db_scheme: str = "postgresql+asyncpg" + inst_conn: str | None = None + admin_scopes: Set[str] = set(["query-groups", "manage-users"]) + db_logging: bool = True + + def __init__(self, **data): + super().__init__(**data) + + @field_validator("inst_conn", mode="before") + @classmethod + def build_postgres_dsn(cls, field_value, info: ValidationInfo) -> str: + postgres_dsn = PostgresDsn.build( + scheme=info.data.get("inst_db_scheme"), + username=info.data.get("inst_db_user"), + password=parse.quote(str(info.data.get("inst_db_pwd")), safe=""), + host=info.data.get("inst_db_host"), + path=info.data.get("inst_db_name"), + ) + return postgres_dsn.unicode_string() + + model_config = SettingsConfigDict(env_file=env_files_to_load, extra="allow") + + +settings = Settings() + +kc_settings = KeycloakSettings(_env_file=env_files_to_load) + +regex_configs = RegexConfigs.instance() diff --git a/src/regtech_user_fi_management/config.py b/src/regtech_user_fi_management/config.py index fac386a..c9445d2 100644 --- a/src/regtech_user_fi_management/config.py +++ b/src/regtech_user_fi_management/config.py @@ -23,7 +23,7 @@ class Settings(BaseSettings): inst_db_user: str inst_db_pwd: str inst_db_host: str - inst_db_scheme: str = "postgresql+asyncpg" + inst_db_scheme: str = "postgresql+psycopg2" inst_conn: str | None = None admin_scopes: Set[str] = set(["query-groups", "manage-users"]) db_logging: bool = True diff --git a/src/regtech_user_fi_management/dependencies.py b/src/regtech_user_fi_management/dependencies.py index 7f49649..7d9045a 100644 --- a/src/regtech_user_fi_management/dependencies.py +++ b/src/regtech_user_fi_management/dependencies.py @@ -9,16 +9,16 @@ from regtech_api_commons.api.dependencies import get_email_domain -async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: +def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: if not request.user.is_authenticated: raise RegTechHttpException( status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="unauthenticated user" ) - if await email_domain_denied(session, get_email_domain(request.user.email)): + if email_domain_denied(session, get_email_domain(request.user.email)): raise RegTechHttpException( status_code=HTTPStatus.FORBIDDEN, name="Request Forbidden", detail="email domain denied" ) -async def email_domain_denied(session: AsyncSession, email: str) -> bool: - return not await repo.is_domain_allowed(session, email) +def email_domain_denied(session: AsyncSession, email: str) -> bool: + return not repo.is_domain_allowed(session, email) diff --git a/src/regtech_user_fi_management/entities/engine/engine.py b/src/regtech_user_fi_management/entities/engine/engine.py index b7d591c..0b32711 100644 --- a/src/regtech_user_fi_management/entities/engine/engine.py +++ b/src/regtech_user_fi_management/entities/engine/engine.py @@ -1,20 +1,16 @@ -from sqlalchemy.ext.asyncio import ( - create_async_engine, - async_sessionmaker, - async_scoped_session, -) -from asyncio import current_task +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker from regtech_user_fi_management.config import settings -engine = create_async_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options( +engine = create_engine(str(settings.inst_conn), echo=settings.db_logging).execution_options( schema_translate_map={None: settings.inst_db_schema} ) -SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) +SessionLocal = scoped_session(sessionmaker(engine, expire_on_commit=False)) -async def get_session(): +def get_session(): session = SessionLocal() try: yield session finally: - await session.close() + session.close() diff --git a/src/regtech_user_fi_management/entities/listeners.py b/src/regtech_user_fi_management/entities/listeners.py index ddae75f..fdf9ec4 100644 --- a/src/regtech_user_fi_management/entities/listeners.py +++ b/src/regtech_user_fi_management/entities/listeners.py @@ -67,14 +67,9 @@ def _insert_history( return _insert_history -async def setup_dao_listeners(): - async with engine.begin() as connection: - fi_history, mapping_history = await connection.run_sync( - lambda conn: ( - Table("financial_institutions_history", Base.metadata, autoload_with=conn), - Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn), - ) - ) +def setup_dao_listeners(): + fi_history = Table("financial_institutions_history", Base.metadata, autoload_with=engine) + mapping_history = Table("fi_to_type_mapping_history", Base.metadata, autoload_with=engine) insert_fi_history = _setup_fi_history(fi_history, mapping_history) diff --git a/src/regtech_user_fi_management/entities/models/dao.py b/src/regtech_user_fi_management/entities/models/dao.py index d0261f6..399d83e 100644 --- a/src/regtech_user_fi_management/entities/models/dao.py +++ b/src/regtech_user_fi_management/entities/models/dao.py @@ -2,10 +2,9 @@ from typing import List from sqlalchemy import ForeignKey, func, String, inspect from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase -from sqlalchemy.ext.asyncio import AsyncAttrs -class Base(AsyncAttrs, DeclarativeBase): +class Base(DeclarativeBase): pass diff --git a/src/regtech_user_fi_management/entities/repos/institutions_repo.py b/src/regtech_user_fi_management/entities/repos/institutions_repo.py index d7210e1..062efe7 100644 --- a/src/regtech_user_fi_management/entities/repos/institutions_repo.py +++ b/src/regtech_user_fi_management/entities/repos/institutions_repo.py @@ -1,12 +1,12 @@ from typing import List, Sequence, Set from sqlalchemy import select, func -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Session, joinedload from sqlalchemy.ext.asyncio import AsyncSession from regtech_api_commons.models.auth import AuthenticatedUser -from .repo_utils import get_associated_sbl_types, query_type +from .repo_utils import get_associated_sbl_types from regtech_user_fi_management.entities.models.dao import ( FinancialInstitutionDao, @@ -25,76 +25,61 @@ ) -async def get_institutions( - session: AsyncSession, +def get_institutions( + session: Session, leis: List[str] | None = None, domain: str = "", page: int = 0, count: int = 100, ) -> Sequence[FinancialInstitutionDao]: - async with session.begin(): - stmt = ( - select(FinancialInstitutionDao) - .options(joinedload(FinancialInstitutionDao.domains)) - .limit(count) - .offset(page * count) - ) - if leis is not None: - stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) - elif d := domain.strip(): - stmt = stmt.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d) - res = await session.scalars(stmt) - return res.unique().all() - - -async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None: - async with session.begin(): - stmt = ( - select(FinancialInstitutionDao) - .options(joinedload(FinancialInstitutionDao.domains)) - .filter(FinancialInstitutionDao.lei == lei) - ) - return await session.scalar(stmt) + query = session.query(FinancialInstitutionDao) + if leis is not None: + query = query.filter(FinancialInstitutionDao.lei.in_(leis)) + elif d := domain.strip(): + query = query.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d) + return query.limit(count).offset(page * count).all() + +def get_institution(session: Session, lei: str) -> FinancialInstitutionDao | None: + return session.get(FinancialInstitutionDao, lei) -async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]: - return await query_type(session, SBLInstitutionTypeDao) +def get_sbl_types(session: Session) -> Sequence[SBLInstitutionTypeDao]: + return session.query(SBLInstitutionTypeDao).all() -async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]: - return await query_type(session, HMDAInstitutionTypeDao) +def get_hmda_types(session: Session) -> Sequence[HMDAInstitutionTypeDao]: + return session.query(HMDAInstitutionTypeDao).all() -async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]: - return await query_type(session, AddressStateDao) +def get_address_states(session: Session) -> Sequence[AddressStateDao]: + return session.query(AddressStateDao).all() -async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]: - return await query_type(session, FederalRegulatorDao) +def get_federal_regulators(session: Session) -> Sequence[FederalRegulatorDao]: + return session.query(FederalRegulatorDao).all() -async def upsert_institution( - session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser + +def upsert_institution( + session: Session, fi: FinancialInstitutionDto, user: AuthenticatedUser ) -> FinancialInstitutionDao: - async with session.begin(): - fi_data = fi.__dict__.copy() - fi_data.pop("_sa_instance_state", None) - fi_data.pop("version", None) + fi_data = fi.__dict__.copy() + fi_data.pop("_sa_instance_state", None) + fi_data.pop("version", None) - if "sbl_institution_types" in fi_data: - types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types) - fi_data["sbl_institution_types"] = types_association + if "sbl_institution_types" in fi_data: + types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types) + fi_data["sbl_institution_types"] = types_association - db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id)) - await session.flush() - await session.refresh(db_fi) - return db_fi + db_fi = session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id)) + session.commit() + return db_fi -async def update_sbl_types( - session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str] +def update_sbl_types( + session: Session, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str] ) -> FinancialInstitutionDao | None: - if fi := await get_institution(session, lei): + if fi := get_institution(session, lei): new_types = set(get_associated_sbl_types(lei, user.id, sbl_types)) old_types = set(fi.sbl_institution_types) add_types = new_types.difference(old_types) @@ -104,34 +89,25 @@ async def update_sbl_types( fi.sbl_institution_types.extend(add_types) for type in fi.sbl_institution_types: type.version = fi.version - await session.commit() - """ - load the async relational attributes so dto can be properly serialized - """ - for type in fi.sbl_institution_types: - await type.awaitable_attrs.sbl_type + session.commit() return fi -async def add_domains( - session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate] +def add_domains( + session: Session, lei: str, domains: List[FinancialInsitutionDomainCreate] ) -> Set[FinancialInstitutionDomainDao]: - async with session.begin(): - daos = set( - map( - lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei), - domains, - ) + daos = set( + map( + lambda dto: FinancialInstitutionDomainDao(domain=dto.domain, lei=lei), + domains, ) - session.add_all(daos) - await session.commit() - return daos + ) + session.add_all(daos) + session.commit() + return daos -async def is_domain_allowed(session: AsyncSession, domain: str) -> bool: +def is_domain_allowed(session: Session, domain: str) -> bool: if domain: - async with session: - stmt = select(func.count()).filter(DeniedDomainDao.domain == domain) - res = await session.scalar(stmt) - return res == 0 + return session.get(DeniedDomainDao, domain) is None return False diff --git a/src/regtech_user_fi_management/entities/repos/repo_utils.py b/src/regtech_user_fi_management/entities/repos/repo_utils.py index 809bdc4..8a165f6 100644 --- a/src/regtech_user_fi_management/entities/repos/repo_utils.py +++ b/src/regtech_user_fi_management/entities/repos/repo_utils.py @@ -1,19 +1,10 @@ -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from typing import Sequence, TypeVar, Type +from typing import Sequence, TypeVar from regtech_user_fi_management.entities.models.dao import Base, SblTypeMappingDao from regtech_user_fi_management.entities.models.dto import SblTypeAssociationDto T = TypeVar("T", bound=Base) -async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]: - async with session.begin(): - stmt = select(type) - res = await session.scalars(stmt) - return res.all() - - def get_associated_sbl_types( lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str] ) -> Sequence[SblTypeMappingDao]: diff --git a/src/regtech_user_fi_management/main.py b/src/regtech_user_fi_management/main.py index 66fc277..0a2432d 100644 --- a/src/regtech_user_fi_management/main.py +++ b/src/regtech_user_fi_management/main.py @@ -41,7 +41,7 @@ async def lifespan(app_: FastAPI): log.info("Starting up...") log.info("run alembic upgrade head...") run_migrations() - await setup_dao_listeners() + setup_dao_listeners() yield log.info("Shutting down...") diff --git a/src/regtech_user_fi_management/routers/institutions.py b/src/regtech_user_fi_management/routers/institutions.py index 6fd4235..0b3ee28 100644 --- a/src/regtech_user_fi_management/routers/institutions.py +++ b/src/regtech_user_fi_management/routers/institutions.py @@ -22,7 +22,7 @@ SblTypeAssociationPatchDto, VersionedData, ) -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from starlette.authentication import requires from regtech_api_commons.models.auth import AuthenticatedUser from regtech_api_commons.api.exceptions import RegTechHttpException @@ -38,7 +38,7 @@ InstitutionType = Literal["sbl", "hmda"] -async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]): +def set_db(request: Request, session: Annotated[Session, Depends(get_session)]): request.state.db_session = session @@ -49,33 +49,33 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_ "/", response_model=List[FinancialInstitutionWithRelationsDto], dependencies=[Depends(verify_institution_search)] ) @requires("authenticated") -async def get_institutions( +def get_institutions( request: Request, leis: List[str] = Depends(parse_leis), domain: str = "", page: int = 0, count: int = 100, ): - return await repo.get_institutions(request.state.db_session, leis, domain, page, count) + return repo.get_institutions(request.state.db_session, leis, domain, page, count) @router.post("/", response_model=Tuple[str, FinancialInstitutionWithRelationsDto], dependencies=[Depends(check_domain)]) @requires(["query-groups", "manage-users"]) -async def create_institution( +def create_institution( request: Request, fi: FinancialInstitutionDto, ): - db_fi = await repo.upsert_institution(request.state.db_session, fi, request.user) + db_fi = repo.upsert_institution(request.state.db_session, fi, request.user) kc_id = oauth2_admin.upsert_group(fi.lei, fi.name) return kc_id, db_fi @router.get("/associated", response_model=List[FinancialInstitutionAssociationDto]) @requires("authenticated") -async def get_associated_institutions(request: Request): +def get_associated_institutions(request: Request): user: AuthenticatedUser = request.user email_domain = get_email_domain(user.email) - associated_institutions = await repo.get_institutions(request.state.db_session, user.institutions) + associated_institutions = repo.get_institutions(request.state.db_session, user.institutions) return [ FinancialInstitutionAssociationDto( **institution.__dict__, @@ -87,35 +87,35 @@ async def get_associated_institutions(request: Request): @router.get("/types/{type}", response_model=List[InstitutionTypeDto]) @requires("authenticated") -async def get_institution_types(request: Request, type: InstitutionType): +def get_institution_types(request: Request, type: InstitutionType): match type: case "sbl": - return await repo.get_sbl_types(request.state.db_session) + return repo.get_sbl_types(request.state.db_session) case "hmda": - return await repo.get_hmda_types(request.state.db_session) + return repo.get_hmda_types(request.state.db_session) @router.get("/address-states", response_model=List[AddressStateDto]) @requires("authenticated") -async def get_address_states(request: Request): - return await repo.get_address_states(request.state.db_session) +def get_address_states(request: Request): + return repo.get_address_states(request.state.db_session) @router.get("/regulators", response_model=List[FederalRegulatorDto]) @requires("authenticated") -async def get_federal_regulators(request: Request): - return await repo.get_federal_regulators(request.state.db_session) +def get_federal_regulators(request: Request): + return repo.get_federal_regulators(request.state.db_session) @router.get( "/{lei}", response_model=FinancialInstitutionWithRelationsDto, dependencies=[Depends(verify_user_lei_relation)] ) @requires("authenticated") -async def get_institution( +def get_institution( request: Request, lei: str, ): - res = await repo.get_institution(request.state.db_session, lei) + res = repo.get_institution(request.state.db_session, lei) if not res: raise RegTechHttpException(HTTPStatus.NOT_FOUND, name="Institution Not Found", detail=f"{lei} not found.") return res @@ -127,10 +127,10 @@ async def get_institution( dependencies=[Depends(verify_user_lei_relation)], ) @requires("authenticated") -async def get_types(request: Request, response: Response, lei: str, type: InstitutionType): +def get_types(request: Request, response: Response, lei: str, type: InstitutionType): match type: case "sbl": - if fi := await repo.get_institution(request.state.db_session, lei): + if fi := repo.get_institution(request.state.db_session, lei): return VersionedData(version=fi.version, data=fi.sbl_institution_types) else: response.status_code = HTTPStatus.NO_CONTENT @@ -146,12 +146,12 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit dependencies=[Depends(verify_user_lei_relation)], ) @requires("authenticated") -async def update_types( +def update_types( request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto ): match type: case "sbl": - if fi := await repo.update_sbl_types( + if fi := repo.update_sbl_types( request.state.db_session, request.user, lei, types_patch.sbl_institution_types ): return VersionedData(version=fi.version, data=fi.sbl_institution_types) @@ -165,14 +165,14 @@ async def update_types( @router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)]) @requires(["query-groups", "manage-users"]) -async def add_domains( +def add_domains( request: Request, lei: str, domains: List[FinancialInsitutionDomainCreate], ): - return await repo.add_domains(request.state.db_session, lei, domains) + return repo.add_domains(request.state.db_session, lei, domains) @router.get("/domains/allowed", response_model=bool) -async def is_domain_allowed(request: Request, domain: str): - return await repo.is_domain_allowed(request.state.db_session, domain) +def is_domain_allowed(request: Request, domain: str): + return repo.is_domain_allowed(request.state.db_session, domain) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index b64302e..c896fc2 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -19,8 +19,8 @@ @pytest.fixture def app_fixture(mocker: MockerFixture) -> FastAPI: - mocked_engine = mocker.patch("sqlalchemy.ext.asyncio.create_async_engine") - MockedEngine = mocker.patch("sqlalchemy.ext.asyncio.AsyncEngine") + mocked_engine = mocker.patch("sqlalchemy.create_engine") + MockedEngine = mocker.patch("sqlalchemy.Engine") mocked_engine.return_value = MockedEngine.return_value mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer") domain_denied_mock = mocker.patch("regtech_user_fi_management.dependencies.email_domain_denied") diff --git a/tests/app/test_config.py b/tests/app/test_config.py index 71cf496..1ca1c1b 100644 --- a/tests/app/test_config.py +++ b/tests/app/test_config.py @@ -10,4 +10,4 @@ def test_postgres_dsn_building(): "inst_db_scehma": "test", } settings = Settings(**mock_config) - assert str(settings.inst_conn) == "postgresql+asyncpg://user:%5Cz9-%2Ftgb76%23%40@test:5432/test" + assert str(settings.inst_conn) == "postgresql+psycopg2://user:%5Cz9-%2Ftgb76%23%40@test:5432/test" diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py index 451ff63..183bb71 100644 --- a/tests/app/test_dependencies.py +++ b/tests/app/test_dependencies.py @@ -9,23 +9,23 @@ def mock_session(mocker: MockerFixture) -> AsyncSession: return mocker.patch("sqlalchemy.ext.asyncio.AsyncSession").return_value -async def test_domain_denied(mocker: MockerFixture, mock_session: AsyncSession): +def test_domain_denied(mocker: MockerFixture, mock_session: AsyncSession): domain_allowed_mock = mocker.patch("regtech_user_fi_management.entities.repos.institutions_repo.is_domain_allowed") domain_allowed_mock.return_value = False from regtech_user_fi_management.dependencies import email_domain_denied denied_domain = "denied.domain" - assert await email_domain_denied(mock_session, denied_domain) is True + assert email_domain_denied(mock_session, denied_domain) is True domain_allowed_mock.assert_called_once_with(mock_session, denied_domain) -async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession): +def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession): domain_allowed_mock = mocker.patch("regtech_user_fi_management.entities.repos.institutions_repo.is_domain_allowed") domain_allowed_mock.return_value = True from regtech_user_fi_management.dependencies import email_domain_denied allowed_domain = "allowed.domain" - assert await email_domain_denied(mock_session, allowed_domain) is False + assert email_domain_denied(mock_session, allowed_domain) is False domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain) diff --git a/tests/entities/conftest.py b/tests/entities/conftest.py index 40a7f7b..c736fbb 100644 --- a/tests/entities/conftest.py +++ b/tests/entities/conftest.py @@ -1,61 +1,41 @@ -import asyncio import pytest -from asyncio import current_task -from sqlalchemy.ext.asyncio import ( - create_async_engine, - AsyncEngine, - async_scoped_session, - async_sessionmaker, -) +from sqlalchemy import create_engine, Engine +from sqlalchemy.orm import scoped_session, sessionmaker from regtech_user_fi_management.entities.models.dao import Base -@pytest.fixture(scope="session") -def event_loop(): - loop = asyncio.get_event_loop() - try: - yield loop - finally: - loop.close() - - @pytest.fixture(scope="session") def engine(): - return create_async_engine("sqlite+aiosqlite://") + return create_engine("sqlite://") @pytest.fixture(scope="function", autouse=True) -async def setup_db( +def setup_db( request: pytest.FixtureRequest, - engine: AsyncEngine, - event_loop: asyncio.AbstractEventLoop, + engine: Engine, ): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - def teardown(): - async def td(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) + Base.metadata.create_all(bind=engine) - event_loop.run_until_complete(td()) + def teardown(): + Base.metadata.drop_all(bind=engine) request.addfinalizer(teardown) @pytest.fixture(scope="function") -async def transaction_session(session_generator: async_scoped_session): - async with session_generator() as session: +def transaction_session(session_generator: scoped_session): + with session_generator() as session: yield session @pytest.fixture(scope="function") -async def query_session(session_generator: async_scoped_session): - async with session_generator() as session: +def query_session(session_generator: scoped_session): + with session_generator() as session: yield session @pytest.fixture(scope="function") -def session_generator(engine: AsyncEngine): - return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) +def session_generator(engine: Engine): + return scoped_session(sessionmaker(engine, expire_on_commit=False)) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 789f57d..c2c2521 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -1,6 +1,6 @@ import pytest from pytest_mock import MockerFixture -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from regtech_user_fi_management.entities.models.dto import ( FinancialInstitutionDto, @@ -25,9 +25,9 @@ class TestInstitutionsRepo: auth_user: AuthenticatedUser = AuthenticatedUser.from_claim({"id": "test_user_id"}) @pytest.fixture(scope="function", autouse=True) - async def setup( + def setup( self, - transaction_session: AsyncSession, + transaction_session: Session, ): state_ga, state_ca, state_fl = ( AddressStateDao(code="GA", name="Georgia"), @@ -148,65 +148,65 @@ async def setup( transaction_session.add(fi_dao_123) transaction_session.add(fi_dao_456) transaction_session.add(fi_dao_sub_456) - await transaction_session.commit() + transaction_session.commit() - async def test_get_sbl_types(self, query_session: AsyncSession): + def test_get_sbl_types(self, query_session: Session): expected_ids = {"1", "2", "13"} - res = await repo.get_sbl_types(query_session) + res = repo.get_sbl_types(query_session) assert len(res) == 3 assert set([r.id for r in res]) == expected_ids - async def test_get_hmda_types(self, query_session: AsyncSession): + def test_get_hmda_types(self, query_session: Session): expected_ids = {"HIT1", "HIT2", "HIT3"} - res = await repo.get_hmda_types(query_session) + res = repo.get_hmda_types(query_session) assert len(res) == 3 assert set([r.id for r in res]) == expected_ids - async def test_get_address_states(self, query_session: AsyncSession): + def test_get_address_states(self, query_session: Session): expected_codes = {"CA", "GA", "FL"} - res = await repo.get_address_states(query_session) + res = repo.get_address_states(query_session) assert len(res) == 3 assert set([r.code for r in res]) == expected_codes - async def test_get_federal_regulators(self, query_session: AsyncSession): + def test_get_federal_regulators(self, query_session: Session): expected_ids = {"FRI1", "FRI2", "FRI3"} - res = await repo.get_federal_regulators(query_session) + res = repo.get_federal_regulators(query_session) assert len(res) == 3 assert set([r.id for r in res]) == expected_ids - async def test_get_institutions(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session) + def test_get_institutions(self, query_session: Session): + res = repo.get_institutions(query_session) assert len(res) == 3 - async def test_get_institutions_by_domain(self, query_session: AsyncSession): + def test_get_institutions_by_domain(self, query_session: Session): # verify 'generic' domain queries don't work - res = await repo.get_institutions(query_session, domain="bank") + res = repo.get_institutions(query_session, domain="bank") assert len(res) == 0 - res = await repo.get_institutions(query_session, domain="test.bank.1") + res = repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 # shouldn't find sub.test.bank.2 - res = await repo.get_institutions(query_session, domain="test.bank.2") + res = repo.get_institutions(query_session, domain="test.bank.2") assert len(res) == 1 - res = await repo.get_institutions(query_session, domain="sub.test.bank.2") + res = repo.get_institutions(query_session, domain="sub.test.bank.2") assert len(res) == 1 - async def test_get_institutions_by_domain_not_existing(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, domain="testing.bank") + def test_get_institutions_by_domain_not_existing(self, query_session: Session): + res = repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 - async def test_get_institutions_by_lei_list(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK123000000000", "TESTBANK456000000000"]) + def test_get_institutions_by_lei_list(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK123000000000", "TESTBANK456000000000"]) assert len(res) == 2 - async def test_get_institutions_by_lei_list_item_not_existing(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["0123NOTTESTBANK01234"]) + def test_get_institutions_by_lei_list_item_not_existing(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["0123NOTTESTBANK01234"]) assert len(res) == 0 - async def test_empty_state(self, transaction_session: AsyncSession): - db_fi = await repo.upsert_institution( + def test_empty_state(self, transaction_session: Session): + db_fi = repo.upsert_institution( transaction_session, FinancialInstitutionDto( name="New Bank 123", @@ -235,14 +235,14 @@ async def test_empty_state(self, transaction_session: AsyncSession): self.auth_user, ) assert db_fi.domains == [] - res = await repo.get_institutions(transaction_session) + res = repo.get_institutions(transaction_session) assert len(res) == 4 new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK1230000000000"])).sbl_institution_types assert len(new_sbl_types) == 1 assert next(iter(new_sbl_types)).sbl_type.name == "Test SBL Instituion ID 1" - async def test_add_institution(self, transaction_session: AsyncSession): - db_fi = await repo.upsert_institution( + def test_add_institution(self, transaction_session: Session): + db_fi = repo.upsert_institution( transaction_session, FinancialInstitutionDto( name="New Bank 123", @@ -271,16 +271,14 @@ async def test_add_institution(self, transaction_session: AsyncSession): self.auth_user, ) assert db_fi.domains == [] - res = await repo.get_institutions(transaction_session) + res = repo.get_institutions(transaction_session) assert len(res) == 4 new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK1230000000000"])).sbl_institution_types assert len(new_sbl_types) == 1 assert next(iter(new_sbl_types)).sbl_type.name == "Test SBL Instituion ID 1" - async def test_add_institution_only_required_fields( - self, transaction_session: AsyncSession, query_session: AsyncSession - ): - await repo.upsert_institution( + def test_add_institution_only_required_fields(self, transaction_session: Session, query_session: Session): + repo.upsert_institution( transaction_session, FinancialInstitutionDto( name="Minimal Bank 123", @@ -293,15 +291,13 @@ async def test_add_institution_only_required_fields( ), self.auth_user, ) - res = await repo.get_institution(query_session, "MINBANK1230000000000") + res = repo.get_institution(query_session, "MINBANK1230000000000") assert res is not None assert res.tax_id is None - async def test_add_institution_missing_required_fields( - self, transaction_session: AsyncSession, query_session: AsyncSession - ): + def test_add_institution_missing_required_fields(self, transaction_session: Session, query_session: Session): with pytest.raises(Exception) as e: - await repo.upsert_institution( + repo.upsert_institution( transaction_session, FinancialInstitutionDto( name="Minimal Bank 123", @@ -310,11 +306,11 @@ async def test_add_institution_missing_required_fields( self.auth_user, ) assert "field required" in str(e.value).lower() - res = await repo.get_institution(query_session, "MINBANK1230000000000") + res = repo.get_institution(query_session, "MINBANK1230000000000") assert res is None - async def test_update_institution(self, transaction_session: AsyncSession): - await repo.upsert_institution( + def test_update_institution(self, transaction_session: Session): + repo.upsert_institution( transaction_session, FinancialInstitutionDto( name="Test Bank 234", @@ -327,79 +323,79 @@ async def test_update_institution(self, transaction_session: AsyncSession): ), self.auth_user, ) - res = await repo.get_institutions(transaction_session) + res = repo.get_institutions(transaction_session) assert len(res) == 3 assert res[0].name == "Test Bank 234" - async def test_add_domains(self, transaction_session: AsyncSession, query_session: AsyncSession): - await repo.add_domains( + def test_add_domains(self, transaction_session: Session, query_session: Session): + repo.add_domains( transaction_session, "TESTBANK123000000000", [FinancialInsitutionDomainCreate(domain="bank.test")], ) - fi = await repo.get_institution(query_session, "TESTBANK123000000000") + transaction_session.expunge_all() + fi = repo.get_institution(query_session, "TESTBANK123000000000") assert len(fi.domains) == 2 - async def test_domain_allowed(self, transaction_session: AsyncSession): + def test_domain_allowed(self, transaction_session: Session): denied_domain = DeniedDomainDao(domain="yahoo.com") transaction_session.add(denied_domain) - await transaction_session.commit() - assert await repo.is_domain_allowed(transaction_session, "yahoo.com") is False - assert await repo.is_domain_allowed(transaction_session, "gmail.com") is True + transaction_session.commit() + assert repo.is_domain_allowed(transaction_session, "yahoo.com") is False + assert repo.is_domain_allowed(transaction_session, "gmail.com") is True - async def test_institution_mapped_to_state_valid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) + def test_institution_mapped_to_state_valid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) assert res[0].hq_address_state.name == "Georgia" - async def test_institution_mapped_to_state_invalid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) + def test_institution_mapped_to_state_invalid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) assert res[0].hq_address_state.name != "Georgia" - async def test_institution_mapped_to_federal_regulator_valid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) + def test_institution_mapped_to_federal_regulator_valid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) assert res[0].primary_federal_regulator.name != "Test Federal Regulator ID 1" - async def test_institution_mapped_to_federal_regulator_invalid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) + def test_institution_mapped_to_federal_regulator_invalid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) assert res[0].primary_federal_regulator.name == "Test Federal Regulator ID 1" - async def test_institution_mapped_to_hmda_it_valid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) + def test_institution_mapped_to_hmda_it_valid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) assert res[0].hmda_institution_type.name == "Test HMDA Instituion ID 1" - async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) + def test_institution_mapped_to_hmda_it_invalid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) assert res[0].hmda_institution_type.name != "Test HMDA Instituion ID 1" - async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) + def test_institution_mapped_to_sbl_it_valid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK123000000000"]) assert res[0].sbl_institution_types[0].sbl_type.name == "Test SBL Instituion ID 1" - async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) + def test_institution_mapped_to_sbl_it_invalid(self, query_session: Session): + res = repo.get_institutions(query_session, leis=["TESTBANK456000000000"]) assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1" - async def test_update_sbl_institution_types( - self, mocker: MockerFixture, query_session: AsyncSession, transaction_session: AsyncSession + def test_update_sbl_institution_types( + self, mocker: MockerFixture, query_session: Session, transaction_session: Session ): test_lei = "TESTBANK123000000000" - existing_inst = await repo.get_institution(query_session, test_lei) + existing_inst = repo.get_institution(query_session, test_lei) + query_session.expunge(existing_inst) sbl_types = [ SblTypeAssociationDto(id="1"), SblTypeAssociationDto(id="2"), SblTypeAssociationDto(id="13", details="test"), ] commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit) - updated_inst = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) + updated_inst = repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) commit_spy.assert_called_once() assert len(existing_inst.sbl_institution_types) == 1 assert len(updated_inst.sbl_institution_types) == 3 diffs = set(updated_inst.sbl_institution_types).difference(set(existing_inst.sbl_institution_types)) assert len(diffs) == 2 - async def test_update_sbl_institution_types_inst_non_exist( - self, mocker: MockerFixture, transaction_session: AsyncSession - ): + def test_update_sbl_institution_types_inst_non_exist(self, mocker: MockerFixture, transaction_session: Session): test_lei = "NONEXISTINGBANK00000" sbl_types = [ SblTypeAssociationDto(id="1"), @@ -407,6 +403,6 @@ async def test_update_sbl_institution_types_inst_non_exist( SblTypeAssociationDto(id="13", details="test"), ] commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit) - res = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) + res = repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types) commit_spy.assert_not_called() assert res is None