From 3bef614ebf25fcd88f6f974f0300497ebfc45430 Mon Sep 17 00:00:00 2001 From: jcadam14 <41971533+jcadam14@users.noreply.github.com> Date: Thu, 4 Jan 2024 12:01:30 -0700 Subject: [PATCH] 69 add endpoints to retrieve from look up data tables (#74) Closes #69 Added institutions/address-states, institutions/regulators, and institutions/types/{type} endpoints Went with a 'generic' InstitutionTypeDTO (since both sbl and hmda tables just have id and name). The /types/ endpoint uses a type Literal to limit the accepted types to sbl or hmda. It then queries the specific Dao type. If this approach isn't desired I can break that GET into two separate types/sbl and types/hmda endpoints that return the specific DTO type. Updated pytests to test these endpoints and the corresponding get_* functions added to the institution repo. --- src/entities/models/__init__.py | 6 ++-- src/entities/models/dao.py | 11 +++---- src/entities/models/dto.py | 20 ++----------- src/entities/repos/institutions_repo.py | 22 ++++++++++++++ src/entities/repos/repo_utils.py | 12 ++++++++ src/routers/institutions.py | 29 ++++++++++++++++++- tests/api/routers/test_institutions_api.py | 29 +++++++++++++++++++ .../entities/repos/test_institutions_repo.py | 24 +++++++++++++++ 8 files changed, 126 insertions(+), 27 deletions(-) create mode 100644 src/entities/repos/repo_utils.py diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index 27a17bb..0394759 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -16,8 +16,7 @@ "SBLInstitutionTypeDao", "AddressStateDao", "FederalRegulatorDto", - "HMDAInstitutionTypeDto", - "SBLInstitutionTypeDto", + "InstitutionTypeDto", "AddressStateDto", ] @@ -41,7 +40,6 @@ UserProfile, AuthenticatedUser, FederalRegulatorDto, - HMDAInstitutionTypeDto, - SBLInstitutionTypeDto, + InstitutionTypeDto, AddressStateDto, ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 6e11577..e4bf81c 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -62,16 +62,17 @@ class FederalRegulatorDao(AuditMixin, Base): name: Mapped[str] = mapped_column(unique=True, nullable=False) -class HMDAInstitutionTypeDao(AuditMixin, Base): - __tablename__ = "hmda_institution_type" +class InstitutionTypeMixin(AuditMixin): id: Mapped[str] = mapped_column(index=True, primary_key=True, unique=True) name: Mapped[str] = mapped_column(unique=True) -class SBLInstitutionTypeDao(AuditMixin, Base): +class HMDAInstitutionTypeDao(InstitutionTypeMixin, Base): + __tablename__ = "hmda_institution_type" + + +class SBLInstitutionTypeDao(InstitutionTypeMixin, Base): __tablename__ = "sbl_institution_type" - id: Mapped[str] = mapped_column(index=True, primary_key=True, unique=True) - name: Mapped[str] = mapped_column(unique=True, nullable=False) class AddressStateDao(AuditMixin, Base): diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 632a68e..192fc9f 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -73,22 +73,8 @@ class Config: from_attributes = True -class HMDAInstitutionTypeBase(BaseModel): +class InstitutionTypeDto(BaseModel): id: str - - -class HMDAInstitutionTypeDto(HMDAInstitutionTypeBase): - name: str - - class Config: - from_attributes = True - - -class SBLInstitutionTypeBase(BaseModel): - id: str - - -class SBLInstitutionTypeDto(SBLInstitutionTypeBase): name: str class Config: @@ -108,8 +94,8 @@ class Config: class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto): primary_federal_regulator: FederalRegulatorDto | None = None - hmda_institution_type: HMDAInstitutionTypeDto | None = None - sbl_institution_type: SBLInstitutionTypeDto | None = None + hmda_institution_type: InstitutionTypeDto | None = None + sbl_institution_type: InstitutionTypeDto | None = None hq_address_state: AddressStateDto domains: List[FinancialInsitutionDomainDto] = [] diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 2af5b7b..177ca03 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -4,12 +4,18 @@ from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession +from .repo_utils import query_type + from entities.models import ( FinancialInstitutionDao, FinancialInstitutionDomainDao, FinancialInstitutionDto, FinancialInsitutionDomainCreate, + HMDAInstitutionTypeDao, + SBLInstitutionTypeDao, DeniedDomainDao, + AddressStateDao, + FederalRegulatorDao, ) @@ -45,6 +51,22 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti return await session.scalar(stmt) +async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]: + return await query_type(session, SBLInstitutionTypeDao) + + +async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]: + return await query_type(session, HMDAInstitutionTypeDao) + + +async def get_address_states(session: AsyncSession) -> List[AddressStateDao]: + return await query_type(session, AddressStateDao) + + +async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]: + return await query_type(session, FederalRegulatorDao) + + async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao: async with session.begin(): stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei) diff --git a/src/entities/repos/repo_utils.py b/src/entities/repos/repo_utils.py new file mode 100644 index 0000000..faa48be --- /dev/null +++ b/src/entities/repos/repo_utils.py @@ -0,0 +1,12 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List, TypeVar + +T = TypeVar("T") + + +async def query_type(session: AsyncSession, type: T) -> List[T]: + async with session.begin(): + stmt = select(type) + res = await session.scalars(stmt) + return res.all() diff --git a/src/routers/institutions.py b/src/routers/institutions.py index a190b05..63961ba 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -3,7 +3,7 @@ from oauth2 import oauth2_admin from util import Router from dependencies import check_domain, parse_leis, get_email_domain -from typing import Annotated, List, Tuple +from typing import Annotated, List, Tuple, Literal from entities.engine import get_session from entities.repos import institutions_repo as repo from entities.models import ( @@ -12,11 +12,16 @@ FinancialInsitutionDomainDto, FinancialInsitutionDomainCreate, FinanicialInstitutionAssociationDto, + InstitutionTypeDto, AuthenticatedUser, + AddressStateDto, + FederalRegulatorDto, ) from sqlalchemy.ext.asyncio import AsyncSession from starlette.authentication import requires +InstitutionType = Literal["sbl", "hmda"] + async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]): request.state.db_session = session @@ -63,6 +68,28 @@ async def get_associated_institutions(request: Request): ] +@router.get("/types/{type}", response_model=List[InstitutionTypeDto]) +@requires("authenticated") +async def get_institution_types(request: Request, type: InstitutionType): + match type: + case "sbl": + return await repo.get_sbl_types(request.state.db_session) + case "hmda": + return await repo.get_hmda_types(request.state.db_session) + + +@router.get("/address-states", response_model=List[AddressStateDto]) +@requires("authenticated") +async def get_address_states(request: Request): + return await repo.get_address_states(request.state.db_session) + + +@router.get("/regulators", response_model=List[FederalRegulatorDto]) +@requires("authenticated") +async def get_federal_regulators(request: Request): + return await repo.get_federal_regulators(request.state.db_session) + + @router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto) @requires("authenticated") async def get_institution( diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 4dd1daf..d5f0abd 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -379,3 +379,32 @@ def test_get_associated_institutions_with_no_institutions( assert res.status_code == 200 get_institutions_mock.assert_called_once_with(ANY, []) assert res.json() == [] + + def test_get_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + mock = mocker.patch("entities.repos.institutions_repo.get_sbl_types") + mock.return_value = [] + client = TestClient(app_fixture) + res = client.get("/v1/institutions/types/sbl") + assert res.status_code == 200 + + mock = mocker.patch("entities.repos.institutions_repo.get_hmda_types") + mock.return_value = [] + res = client.get("/v1/institutions/types/hmda") + assert res.status_code == 200 + + res = client.get("/v1/institutions/types/blah") + assert res.status_code == 422 + + def test_get_address_states(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + mock = mocker.patch("entities.repos.institutions_repo.get_address_states") + mock.return_value = [] + client = TestClient(app_fixture) + res = client.get("/v1/institutions/address-states") + assert res.status_code == 200 + + def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + mock = mocker.patch("entities.repos.institutions_repo.get_federal_regulators") + mock.return_value = [] + client = TestClient(app_fixture) + res = client.get("/v1/institutions/regulators") + assert res.status_code == 200 diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 50eee18..e8c9d8e 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -132,6 +132,30 @@ async def setup( transaction_session.add(fi_dao_sub_456) await transaction_session.commit() + async def test_get_sbl_types(self, query_session: AsyncSession): + expected_ids = {"SIT1", "SIT2", "SIT3"} + res = await repo.get_sbl_types(query_session) + assert len(res) == 3 + assert set([r.id for r in res]) == expected_ids + + async def test_get_hmda_types(self, query_session: AsyncSession): + expected_ids = {"HIT1", "HIT2", "HIT3"} + res = await repo.get_hmda_types(query_session) + assert len(res) == 3 + assert set([r.id for r in res]) == expected_ids + + async def test_get_address_states(self, query_session: AsyncSession): + expected_codes = {"CA", "GA", "FL"} + res = await repo.get_address_states(query_session) + assert len(res) == 3 + assert set([r.code for r in res]) == expected_codes + + async def test_get_federal_regulators(self, query_session: AsyncSession): + expected_ids = {"FRI1", "FRI2", "FRI3"} + res = await repo.get_federal_regulators(query_session) + assert len(res) == 3 + assert set([r.id for r in res]) == expected_ids + async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) assert len(res) == 3