Skip to content

Commit

Permalink
69 add endpoints to retrieve from look up data tables (#74)
Browse files Browse the repository at this point in the history
Closes #69 

Added institutions/address-states, institutions/regulators, and
institutions/types/{type} endpoints

Went with a 'generic' InstitutionTypeDTO (since both sbl and hmda tables
just have id and name). The /types/ endpoint uses a type Literal to
limit the accepted types to sbl or hmda. It then queries the specific
Dao type. If this approach isn't desired I can break that GET into two
separate types/sbl and types/hmda endpoints that return the specific DTO
type.

Updated pytests to test these endpoints and the corresponding get_*
functions added to the institution repo.
  • Loading branch information
jcadam14 authored Jan 4, 2024
1 parent 58c1a3f commit 3bef614
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 27 deletions.
6 changes: 2 additions & 4 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
"SBLInstitutionTypeDao",
"AddressStateDao",
"FederalRegulatorDto",
"HMDAInstitutionTypeDto",
"SBLInstitutionTypeDto",
"InstitutionTypeDto",
"AddressStateDto",
]

Expand All @@ -41,7 +40,6 @@
UserProfile,
AuthenticatedUser,
FederalRegulatorDto,
HMDAInstitutionTypeDto,
SBLInstitutionTypeDto,
InstitutionTypeDto,
AddressStateDto,
)
11 changes: 6 additions & 5 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,17 @@ class FederalRegulatorDao(AuditMixin, Base):
name: Mapped[str] = mapped_column(unique=True, nullable=False)


class HMDAInstitutionTypeDao(AuditMixin, Base):
__tablename__ = "hmda_institution_type"
class InstitutionTypeMixin(AuditMixin):
id: Mapped[str] = mapped_column(index=True, primary_key=True, unique=True)
name: Mapped[str] = mapped_column(unique=True)


class SBLInstitutionTypeDao(AuditMixin, Base):
class HMDAInstitutionTypeDao(InstitutionTypeMixin, Base):
__tablename__ = "hmda_institution_type"


class SBLInstitutionTypeDao(InstitutionTypeMixin, Base):
__tablename__ = "sbl_institution_type"
id: Mapped[str] = mapped_column(index=True, primary_key=True, unique=True)
name: Mapped[str] = mapped_column(unique=True, nullable=False)


class AddressStateDao(AuditMixin, Base):
Expand Down
20 changes: 3 additions & 17 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,8 @@ class Config:
from_attributes = True


class HMDAInstitutionTypeBase(BaseModel):
class InstitutionTypeDto(BaseModel):
id: str


class HMDAInstitutionTypeDto(HMDAInstitutionTypeBase):
name: str

class Config:
from_attributes = True


class SBLInstitutionTypeBase(BaseModel):
id: str


class SBLInstitutionTypeDto(SBLInstitutionTypeBase):
name: str

class Config:
Expand All @@ -108,8 +94,8 @@ class Config:

class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto):
primary_federal_regulator: FederalRegulatorDto | None = None
hmda_institution_type: HMDAInstitutionTypeDto | None = None
sbl_institution_type: SBLInstitutionTypeDto | None = None
hmda_institution_type: InstitutionTypeDto | None = None
sbl_institution_type: InstitutionTypeDto | None = None
hq_address_state: AddressStateDto
domains: List[FinancialInsitutionDomainDto] = []

Expand Down
22 changes: 22 additions & 0 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from .repo_utils import query_type

from entities.models import (
FinancialInstitutionDao,
FinancialInstitutionDomainDao,
FinancialInstitutionDto,
FinancialInsitutionDomainCreate,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
)


Expand Down Expand Up @@ -45,6 +51,22 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)


async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)


async def get_address_states(session: AsyncSession) -> List[AddressStateDao]:
return await query_type(session, AddressStateDao)


async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)


async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei)
Expand Down
12 changes: 12 additions & 0 deletions src/entities/repos/repo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, TypeVar

T = TypeVar("T")


async def query_type(session: AsyncSession, type: T) -> List[T]:
async with session.begin():
stmt = select(type)
res = await session.scalars(stmt)
return res.all()
29 changes: 28 additions & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from oauth2 import oauth2_admin
from util import Router
from dependencies import check_domain, parse_leis, get_email_domain
from typing import Annotated, List, Tuple
from typing import Annotated, List, Tuple, Literal
from entities.engine import get_session
from entities.repos import institutions_repo as repo
from entities.models import (
Expand All @@ -12,11 +12,16 @@
FinancialInsitutionDomainDto,
FinancialInsitutionDomainCreate,
FinanicialInstitutionAssociationDto,
InstitutionTypeDto,
AuthenticatedUser,
AddressStateDto,
FederalRegulatorDto,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires

InstitutionType = Literal["sbl", "hmda"]


async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session
Expand Down Expand Up @@ -63,6 +68,28 @@ async def get_associated_institutions(request: Request):
]


@router.get("/types/{type}", response_model=List[InstitutionTypeDto])
@requires("authenticated")
async def get_institution_types(request: Request, type: InstitutionType):
match type:
case "sbl":
return await repo.get_sbl_types(request.state.db_session)
case "hmda":
return await repo.get_hmda_types(request.state.db_session)


@router.get("/address-states", response_model=List[AddressStateDto])
@requires("authenticated")
async def get_address_states(request: Request):
return await repo.get_address_states(request.state.db_session)


@router.get("/regulators", response_model=List[FederalRegulatorDto])
@requires("authenticated")
async def get_federal_regulators(request: Request):
return await repo.get_federal_regulators(request.state.db_session)


@router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto)
@requires("authenticated")
async def get_institution(
Expand Down
29 changes: 29 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,32 @@ def test_get_associated_institutions_with_no_institutions(
assert res.status_code == 200
get_institutions_mock.assert_called_once_with(ANY, [])
assert res.json() == []

def test_get_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_sbl_types")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/types/sbl")
assert res.status_code == 200

mock = mocker.patch("entities.repos.institutions_repo.get_hmda_types")
mock.return_value = []
res = client.get("/v1/institutions/types/hmda")
assert res.status_code == 200

res = client.get("/v1/institutions/types/blah")
assert res.status_code == 422

def test_get_address_states(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_address_states")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/address-states")
assert res.status_code == 200

def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.get_federal_regulators")
mock.return_value = []
client = TestClient(app_fixture)
res = client.get("/v1/institutions/regulators")
assert res.status_code == 200
24 changes: 24 additions & 0 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ async def setup(
transaction_session.add(fi_dao_sub_456)
await transaction_session.commit()

async def test_get_sbl_types(self, query_session: AsyncSession):
expected_ids = {"SIT1", "SIT2", "SIT3"}
res = await repo.get_sbl_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_hmda_types(self, query_session: AsyncSession):
expected_ids = {"HIT1", "HIT2", "HIT3"}
res = await repo.get_hmda_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_address_states(self, query_session: AsyncSession):
expected_codes = {"CA", "GA", "FL"}
res = await repo.get_address_states(query_session)
assert len(res) == 3
assert set([r.code for r in res]) == expected_codes

async def test_get_federal_regulators(self, query_session: AsyncSession):
expected_ids = {"FRI1", "FRI2", "FRI3"}
res = await repo.get_federal_regulators(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids

async def test_get_institutions(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session)
assert len(res) == 3
Expand Down

0 comments on commit 3bef614

Please sign in to comment.