Skip to content

Commit

Permalink
Feature/89 add details in mapping model (#92)
Browse files Browse the repository at this point in the history
closes #89
  • Loading branch information
lchen-2101 authored Jan 17, 2024
1 parent 21c52a1 commit 9212ba0
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""240111 add free form text for sbl types
Revision ID: 6826f05140cd
Revises: ada681e1877f
Create Date: 2024-01-11 14:33:56.518611
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "6826f05140cd"
down_revision: Union[str, None] = "ada681e1877f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column("fi_to_type_mapping", sa.Column("details", type_=sa.String(), nullable=True))


def downgrade() -> None:
op.drop_column("fi_to_type_mapping", "details")
6 changes: 6 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"FederalRegulatorDto",
"InstitutionTypeDto",
"AddressStateDto",
"SblTypeMappingDao",
"SblTypeAssociationDto",
"SblTypeAssociationDetailsDto",
]

from .dao import (
Expand All @@ -29,6 +32,7 @@
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
AddressStateDao,
SblTypeMappingDao,
)
from .dto import (
FinancialInstitutionDto,
Expand All @@ -42,4 +46,6 @@
FederalRegulatorDto,
InstitutionTypeDto,
AddressStateDto,
SblTypeAssociationDto,
SblTypeAssociationDetailsDto,
)
20 changes: 8 additions & 12 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from datetime import datetime
from typing import List
from sqlalchemy import ForeignKey, func, String, Table, Column
from sqlalchemy import ForeignKey, func, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy


class Base(AsyncAttrs, DeclarativeBase):
Expand All @@ -15,12 +14,12 @@ class AuditMixin(object):
event_time: Mapped[datetime] = mapped_column(server_default=func.now())


fi_to_type_mapping = Table(
"fi_to_type_mapping",
Base.metadata,
Column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True),
Column("type_id", ForeignKey("sbl_institution_type.id"), primary_key=True),
)
class SblTypeMappingDao(Base):
__tablename__ = "fi_to_type_mapping"
lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True)
type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True)
sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin")
details: Mapped[str] = mapped_column(nullable=True)


class FinancialInstitutionDao(AuditMixin, Base):
Expand All @@ -37,10 +36,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
primary_federal_regulator: Mapped["FederalRegulatorDao"] = relationship(lazy="selectin")
hmda_institution_type_id: Mapped[str] = mapped_column(ForeignKey("hmda_institution_type.id"), nullable=True)
hmda_institution_type: Mapped["HMDAInstitutionTypeDao"] = relationship(lazy="selectin")
sbl_institution_types: Mapped[List["SBLInstitutionTypeDao"]] = relationship(
lazy="selectin", secondary=fi_to_type_mapping
)
sbl_institution_type_ids: AssociationProxy[List[str]] = association_proxy("sbl_institution_types", "id")
sbl_institution_types: Mapped[List[SblTypeMappingDao]] = relationship(lazy="selectin", cascade="all, delete-orphan")
hq_address_street_1: Mapped[str]
hq_address_street_2: Mapped[str] = mapped_column(nullable=True)
hq_address_city: Mapped[str]
Expand Down
36 changes: 33 additions & 3 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Dict, Any, Set
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from starlette.authentication import BaseUser


Expand All @@ -24,12 +24,34 @@ class FinancialInstitutionBase(BaseModel):
is_active: bool


class SblTypeAssociationDto(BaseModel):
id: str
details: str | None = None

@model_validator(mode="after")
def validate_type(self) -> "SblTypeAssociationDto":
"""
Validates `Other` type and free form input.
If `Other` is selected, then details should be filled in;
vice versa if `Other` is not selected, then details should be null.
"""
other_type_id = "13"
if self.id == other_type_id and not self.details:
raise ValueError(f"SBL institution type '{other_type_id}' requires additional details.")
elif self.id != other_type_id:
self.details = None
return self

class Config:
from_attributes = True


class FinancialInstitutionDto(FinancialInstitutionBase):
tax_id: str | None = None
rssd_id: int | None = None
primary_federal_regulator_id: str | None = None
hmda_institution_type_id: str | None = None
sbl_institution_type_ids: List[str] = []
sbl_institution_types: List[SblTypeAssociationDto | str] = []
hq_address_street_1: str
hq_address_street_2: str | None = None
hq_address_city: str
Expand Down Expand Up @@ -81,6 +103,14 @@ class Config:
from_attributes = True


class SblTypeAssociationDetailsDto(BaseModel):
sbl_type: InstitutionTypeDto
details: str | None = None

class Config:
from_attributes = True


class AddressStateBase(BaseModel):
code: str

Expand All @@ -95,7 +125,7 @@ class Config:
class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto):
primary_federal_regulator: FederalRegulatorDto | None = None
hmda_institution_type: InstitutionTypeDto | None = None
sbl_institution_types: List[InstitutionTypeDto] = []
sbl_institution_types: List[SblTypeAssociationDetailsDto] = []
hq_address_state: AddressStateDto
domains: List[FinancialInsitutionDomainDto] = []

Expand Down
18 changes: 9 additions & 9 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -72,15 +73,14 @@ async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto)
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)

# Populate with model objects from SBLInstitutionTypeDao and clear out
# the id field since it's just a view
if "sbl_institution_type_ids" in fi_data:
sbl_type_stmt = select(SBLInstitutionTypeDao).filter(
SBLInstitutionTypeDao.id.in_(fi_data["sbl_institution_type_ids"])
)
sbl_types = await session.scalars(sbl_type_stmt)
fi_data["sbl_institution_types"] = sbl_types.all()
del fi_data["sbl_institution_type_ids"]
if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details)
for t in fi.sbl_institution_types
]
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
return await session.get(FinancialInstitutionDao, db_fi.lei)
Expand Down
3 changes: 2 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AddressStateDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -70,7 +71,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock:
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down
29 changes: 25 additions & 4 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AddressStateDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -47,7 +48,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas
primary_federal_regulator=FederalRegulatorDao(id="FRI2", name="FRI2"),
hmda_institution_type_id="HIT2",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT2", name="HIT2"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT2", name="SIT2")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT2", name="SIT2"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down Expand Up @@ -140,6 +141,26 @@ def test_create_institution_missing_required_field(
)
assert res.status_code == 422

def test_create_institution_missing_sbl_type_free_form(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.post(
"/v1/institutions/",
json={
"name": "testName",
"lei": "testLei",
"is_active": True,
"hq_address_street_1": "Test Address Street 1",
"hq_address_city": "Test City 1",
"hq_address_state_code": "VA",
"hq_address_zip": "00000",
"sbl_institution_types": [{"id": "13"}],
},
)
assert res.status_code == 422
assert "requires additional details." in res.json()["detail"][0]["msg"]

def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
claims = {
"name": "test",
Expand Down Expand Up @@ -197,7 +218,7 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down Expand Up @@ -294,7 +315,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -319,7 +340,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand Down
24 changes: 13 additions & 11 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
FederalRegulatorDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)
from entities.repos import institutions_repo as repo

Expand All @@ -39,9 +41,9 @@ async def setup(
HMDAInstitutionTypeDao(id="HIT3", name="Test HMDA Instituion ID 3"),
)
sbl_it_dao_sit1, sbl_it_dao_sit2, sbl_it_dao_sit3 = (
SBLInstitutionTypeDao(id="SIT1", name="Test SBL Instituion ID 1"),
SBLInstitutionTypeDao(id="SIT2", name="Test SBL Instituion ID 2"),
SBLInstitutionTypeDao(id="SIT3", name="Test SBL Instituion ID 3"),
SBLInstitutionTypeDao(id="1", name="Test SBL Instituion ID 1"),
SBLInstitutionTypeDao(id="2", name="Test SBL Instituion ID 2"),
SBLInstitutionTypeDao(id="13", name="Test SBL Instituion ID Other"),
)
fi_dao_123, fi_dao_456, fi_dao_sub_456 = (
FinancialInstitutionDao(
Expand All @@ -53,7 +55,7 @@ async def setup(
rssd_id=1234,
primary_federal_regulator_id="FRI1",
hmda_institution_type_id="HIT1",
sbl_institution_types=[sbl_it_dao_sit1],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -75,7 +77,7 @@ async def setup(
rssd_id=4321,
primary_federal_regulator_id="FRI2",
hmda_institution_type_id="HIT2",
sbl_institution_types=[sbl_it_dao_sit2],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand All @@ -97,7 +99,7 @@ async def setup(
rssd_id=2134,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_types=[sbl_it_dao_sit3],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand Down Expand Up @@ -134,7 +136,7 @@ async def setup(
await transaction_session.commit()

async def test_get_sbl_types(self, query_session: AsyncSession):
expected_ids = {"SIT1", "SIT2", "SIT3"}
expected_ids = {"1", "2", "13"}
res = await repo.get_sbl_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids
Expand Down Expand Up @@ -199,7 +201,7 @@ async def test_add_institution(self, transaction_session: AsyncSession):
rssd_id=6543,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_type_ids=["SIT3"],
sbl_institution_types=[SblTypeAssociationDto(id="1")],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand All @@ -218,7 +220,7 @@ async def test_add_institution(self, transaction_session: AsyncSession):
assert len(res) == 4
new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK123"])).sbl_institution_types
assert len(new_sbl_types) == 1
assert next(iter(new_sbl_types)).name == "Test SBL Instituion ID 3"
assert next(iter(new_sbl_types)).sbl_type.name == "Test SBL Instituion ID 1"

async def test_add_institution_only_required_fields(
self, transaction_session: AsyncSession, query_session: AsyncSession
Expand Down Expand Up @@ -313,8 +315,8 @@ async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncS

async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK123"])
assert res[0].sbl_institution_types[0].name == "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].sbl_type.name == "Test SBL Instituion ID 1"

async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK456"])
assert res[0].sbl_institution_types[0].name != "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1"
Loading

0 comments on commit 9212ba0

Please sign in to comment.