Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add patch for sbl institution types #110

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"SblTypeMappingDao",
"SblTypeAssociationDto",
"SblTypeAssociationDetailsDto",
"SblTypeAssociationPatchDto",
"VersionedData",
]

from .dao import (
Expand All @@ -46,4 +48,6 @@
AddressStateDto,
SblTypeAssociationDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
6 changes: 6 additions & 0 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class SblTypeMappingDao(Base):
details: Mapped[str] = mapped_column(nullable=True)
modified_by: Mapped[str] = mapped_column()

def __eq__(self, other: "SblTypeMappingDao") -> bool:
return self.lei == other.lei and self.type_id == other.type_id and self.details == other.details

def __hash__(self) -> int:
return hash((self.lei, self.type_id, self.details))

def as_db_dict(self):
data = {}
for attr, column in inspect(self.__class__).c.items():
Expand Down
14 changes: 13 additions & 1 deletion src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from typing import List, Set
from typing import Generic, List, Set, Sequence
from pydantic import BaseModel, model_validator
from typing import TypeVar

T = TypeVar("T")


class VersionedData(BaseModel, Generic[T]):
version: int
data: T


class FinancialInsitutionDomainBase(BaseModel):
Expand Down Expand Up @@ -45,6 +53,10 @@ class Config:
from_attributes = True


class SblTypeAssociationPatchDto(BaseModel):
sbl_institution_types: Sequence[SblTypeAssociationDto | str]


class FinancialInstitutionDto(FinancialInstitutionBase):
tax_id: str | None = None
rssd_id: int | None = None
Expand Down
52 changes: 35 additions & 17 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List
from typing import List, Sequence, Set

from sqlalchemy import select, func
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from regtech_api_commons.models import AuthenticatedUser

from .repo_utils import query_type
from .repo_utils import get_associated_sbl_types, query_type

from entities.models import (
FinancialInstitutionDao,
Expand All @@ -18,17 +18,17 @@
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


async def get_institutions(
session: AsyncSession,
leis: List[str] = None,
leis: List[str] | None = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> List[FinancialInstitutionDao]:
) -> Sequence[FinancialInstitutionDao]:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -44,7 +44,7 @@ async def get_institutions(
return res.unique().all()


async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao:
async def get_institution(session: AsyncSession, lei: str) -> FinancialInstitutionDao | None:
async with session.begin():
stmt = (
select(FinancialInstitutionDao)
Expand All @@ -54,19 +54,19 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def get_sbl_types(session: AsyncSession) -> List[SBLInstitutionTypeDao]:
async def get_sbl_types(session: AsyncSession) -> Sequence[SBLInstitutionTypeDao]:
return await query_type(session, SBLInstitutionTypeDao)


async def get_hmda_types(session: AsyncSession) -> List[HMDAInstitutionTypeDao]:
async def get_hmda_types(session: AsyncSession) -> Sequence[HMDAInstitutionTypeDao]:
return await query_type(session, HMDAInstitutionTypeDao)


async def get_address_states(session: AsyncSession) -> List[AddressStateDao]:
async def get_address_states(session: AsyncSession) -> Sequence[AddressStateDao]:
return await query_type(session, AddressStateDao)


async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulatorDao]:
async def get_federal_regulators(session: AsyncSession) -> Sequence[FederalRegulatorDao]:
return await query_type(session, FederalRegulatorDao)


Expand All @@ -79,12 +79,7 @@ async def upsert_institution(
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id)
for t in fi.sbl_institution_types
]
types_association = get_associated_sbl_types(fi.lei, user.id, fi.sbl_institution_types)
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
Expand All @@ -93,9 +88,32 @@ async def upsert_institution(
return db_fi


async def update_sbl_types(
session: AsyncSession, user: AuthenticatedUser, lei: str, sbl_types: Sequence[SblTypeAssociationDto | str]
) -> FinancialInstitutionDao | None:
fi = await get_institution(session, lei)
if fi:
new_types = set(get_associated_sbl_types(lei, user.id, sbl_types))
old_types = set(fi.sbl_institution_types)
add_types = new_types.difference(old_types)
remove_types = old_types.difference(new_types)

fi.sbl_institution_types = [type for type in fi.sbl_institution_types if type not in remove_types]
fi.sbl_institution_types.extend(add_types)
for type in fi.sbl_institution_types:
type.version = fi.version
await session.commit()
"""
load the async relational attributes so dto can be properly serialized
"""
for type in fi.sbl_institution_types:
await type.awaitable_attrs.sbl_type
return fi


async def add_domains(
session: AsyncSession, lei: str, domains: List[FinancialInsitutionDomainCreate]
) -> List[FinancialInstitutionDomainDao]:
) -> Set[FinancialInstitutionDomainDao]:
async with session.begin():
daos = set(
map(
Expand Down
18 changes: 15 additions & 3 deletions src/entities/repos/repo_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, TypeVar
from typing import Sequence, TypeVar, Type
from entities.models import Base, SblTypeMappingDao, SblTypeAssociationDto

T = TypeVar("T")
T = TypeVar("T", bound=Base)


async def query_type(session: AsyncSession, type: T) -> List[T]:
async def query_type(session: AsyncSession, type: Type[T]) -> Sequence[T]:
async with session.begin():
stmt = select(type)
res = await session.scalars(stmt)
return res.all()


def get_associated_sbl_types(
lei: str, user_id: str, types: Sequence[SblTypeAssociationDto | str]
) -> Sequence[SblTypeMappingDao]:
return [
SblTypeMappingDao(type_id=t, lei=lei, modified_by=user_id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=lei, modified_by=user_id)
for t in types
]
27 changes: 27 additions & 0 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
InstitutionTypeDto,
AddressStateDto,
FederalRegulatorDto,
SblTypeAssociationDetailsDto,
SblTypeAssociationPatchDto,
VersionedData,
)
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.authentication import requires
Expand Down Expand Up @@ -105,6 +108,30 @@ async def get_institution(
return res


@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
async def get_types(request: Request, lei: str, type: InstitutionType):
match type:
case "sbl":
fi = await repo.get_institution(request.state.db_session, lei)
return VersionedData(version=fi.version, data=fi.sbl_institution_types) if fi else None
case "hmda":
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported")


@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
async def update_types(request: Request, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto):
match type:
case "sbl":
fi = await repo.update_sbl_types(
request.state.db_session, request.user, lei, types_patch.sbl_institution_types
)
return VersionedData(version=fi.version, data=fi.sbl_institution_types) if fi else None
case "hmda":
raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED, detail="HMDA type not yet supported")


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the question I put in Mattermost when you lost power:

Wondering if we should be returning 204 NO_CONTENT instead of None? I've done that with a few of the endpoints in filing, and added it to the wiki for others to bring up if we want to decide on a consistent approach

Thoughts?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeap, the sbl case I've updated to use no_content instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmda one is a question mark at the moment; if we do allow more open modifications later down the line, I'm guessing the hmda one will be more like the other normal fields, rather than the special case types dealing

@router.post("/{lei}/domains/", response_model=List[FinancialInsitutionDomainDto], dependencies=[Depends(check_domain)])
@requires(["query-groups", "manage-users"])
async def add_domains(
Expand Down
83 changes: 83 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from http import HTTPStatus
from unittest.mock import Mock, ANY

from fastapi import FastAPI
Expand All @@ -13,6 +14,7 @@
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)


Expand Down Expand Up @@ -425,3 +427,84 @@ def test_get_federal_regulators(self, mocker: MockerFixture, app_fixture: FastAP
client = TestClient(app_fixture)
res = client.get("/v1/institutions/regulators")
assert res.status_code == 200

def test_get_sbl_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
inst_version = 2
get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution")
get_institution_mock.return_value = FinancialInstitutionDao(
version=inst_version,
name="Test Bank 123",
lei="TESTBANK123",
is_active=True,
domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")],
tax_id="123456789",
rssd_id=1234,
primary_federal_regulator_id="FRI1",
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
hq_address_state_code="GA",
hq_address_state=AddressStateDao(code="GA", name="Georgia"),
hq_address_zip="00000",
parent_lei="PARENTTESTBANK123",
parent_legal_name="PARENT TEST BANK 123",
parent_rssd_id=12345,
top_holder_lei="TOPHOLDERLEI123",
top_holder_legal_name="TOP HOLDER LEI 123",
top_holder_rssd_id=123456,
)
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.get(f"/v1/institutions/{test_lei}/types/sbl")
assert res.status_code == HTTPStatus.OK
result = res.json()
assert len(result["data"]) == 1
assert result["version"] == inst_version
assert result["data"][0] == {"sbl_type": {"id": "SIT1", "name": "SIT1"}, "details": None}

def test_get_hmda_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.get(f"/v1/institutions/{test_lei}/types/hmda")
assert res.status_code == HTTPStatus.NOT_IMPLEMENTED

def test_update_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/sbl",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.OK
mock.assert_called_once_with(
ANY, ANY, test_lei, ["1", SblTypeAssociationDto(id="2"), SblTypeAssociationDto(id="13", details="test")]
)

def test_update_unsupported_institution_types(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/hmda",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.NOT_IMPLEMENTED
mock.assert_not_called()

def test_update_wrong_institution_types(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
mock = mocker.patch("entities.repos.institutions_repo.update_sbl_types")
client = TestClient(app_fixture)
test_lei = "TESTBANK123"
res = client.put(
f"/v1/institutions/{test_lei}/types/test",
json={"sbl_institution_types": ["1", {"id": "2"}, {"id": "13", "details": "test"}]},
)
assert res.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
mock.assert_not_called()
33 changes: 33 additions & 0 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession

from entities.models import (
Expand Down Expand Up @@ -333,3 +334,35 @@ async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSess
async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK456"])
assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1"

async def test_update_sbl_institution_types(
self, mocker: MockerFixture, query_session: AsyncSession, transaction_session: AsyncSession
):
test_lei = "TESTBANK123"
existing_inst = await repo.get_institution(query_session, test_lei)
sbl_types = [
SblTypeAssociationDto(id="1"),
SblTypeAssociationDto(id="2"),
SblTypeAssociationDto(id="13", details="test"),
]
commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit)
updated_inst = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types)
commit_spy.assert_called_once()
assert len(existing_inst.sbl_institution_types) == 1
assert len(updated_inst.sbl_institution_types) == 3
diffs = set(updated_inst.sbl_institution_types).difference(set(existing_inst.sbl_institution_types))
assert len(diffs) == 2

async def test_update_sbl_institution_types_inst_non_exist(
self, mocker: MockerFixture, transaction_session: AsyncSession
):
test_lei = "NONEXISTINGBANK"
sbl_types = [
SblTypeAssociationDto(id="1"),
SblTypeAssociationDto(id="2"),
SblTypeAssociationDto(id="13", details="test"),
]
commit_spy = mocker.patch.object(transaction_session, "commit", wraps=transaction_session.commit)
res = await repo.update_sbl_types(transaction_session, self.auth_user, test_lei, sbl_types)
commit_spy.assert_not_called()
assert res is None
Loading