Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/89 add details in mapping model #92

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""240111 add free form text for sbl types

Revision ID: 6826f05140cd
Revises: ada681e1877f
Create Date: 2024-01-11 14:33:56.518611

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "6826f05140cd"
down_revision: Union[str, None] = "ada681e1877f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column("fi_to_type_mapping", sa.Column("details", type_=sa.String(), nullable=True))


def downgrade() -> None:
op.drop_column("fi_to_type_mapping", "details")
6 changes: 6 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"FederalRegulatorDto",
"InstitutionTypeDto",
"AddressStateDto",
"SblTypeMappingDao",
"SblTypeAssociationDto",
"SblTypeAssociationDetailsDto",
]

from .dao import (
Expand All @@ -29,6 +32,7 @@
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
AddressStateDao,
SblTypeMappingDao,
)
from .dto import (
FinancialInstitutionDto,
Expand All @@ -42,4 +46,6 @@
FederalRegulatorDto,
InstitutionTypeDto,
AddressStateDto,
SblTypeAssociationDto,
SblTypeAssociationDetailsDto,
)
20 changes: 8 additions & 12 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from datetime import datetime
from typing import List
from sqlalchemy import ForeignKey, func, String, Table, Column
from sqlalchemy import ForeignKey, func, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy


class Base(AsyncAttrs, DeclarativeBase):
Expand All @@ -15,12 +14,12 @@ class AuditMixin(object):
event_time: Mapped[datetime] = mapped_column(server_default=func.now())


fi_to_type_mapping = Table(
"fi_to_type_mapping",
Base.metadata,
Column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True),
Column("type_id", ForeignKey("sbl_institution_type.id"), primary_key=True),
)
class SblTypeMappingDao(Base):
__tablename__ = "fi_to_type_mapping"
lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True)
type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True)
sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin")
details: Mapped[str] = mapped_column(nullable=True)


class FinancialInstitutionDao(AuditMixin, Base):
Expand All @@ -37,10 +36,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
primary_federal_regulator: Mapped["FederalRegulatorDao"] = relationship(lazy="selectin")
hmda_institution_type_id: Mapped[str] = mapped_column(ForeignKey("hmda_institution_type.id"), nullable=True)
hmda_institution_type: Mapped["HMDAInstitutionTypeDao"] = relationship(lazy="selectin")
sbl_institution_types: Mapped[List["SBLInstitutionTypeDao"]] = relationship(
lazy="selectin", secondary=fi_to_type_mapping
)
sbl_institution_type_ids: AssociationProxy[List[str]] = association_proxy("sbl_institution_types", "id")
sbl_institution_types: Mapped[List[SblTypeMappingDao]] = relationship(lazy="selectin", cascade="all, delete-orphan")
hq_address_street_1: Mapped[str]
hq_address_street_2: Mapped[str] = mapped_column(nullable=True)
hq_address_city: Mapped[str]
Expand Down
36 changes: 33 additions & 3 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Dict, Any, Set
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from starlette.authentication import BaseUser


Expand All @@ -24,12 +24,34 @@ class FinancialInstitutionBase(BaseModel):
is_active: bool


class SblTypeAssociationDto(BaseModel):
id: str
details: str | None = None

@model_validator(mode="after")
def validate_type(self) -> "SblTypeAssociationDto":
"""
Validates `Other` type and free form input.
If `Other` is selected, then details should be filled in;
vice versa if `Other` is not selected, then details should be null.
"""
other_type_id = "13"
if self.id == other_type_id and not self.details:
raise ValueError(f"SBL institution type '{other_type_id}' requires additional details.")
elif self.id != other_type_id:
self.details = None
return self

class Config:
from_attributes = True


class FinancialInstitutionDto(FinancialInstitutionBase):
tax_id: str | None = None
rssd_id: int | None = None
primary_federal_regulator_id: str | None = None
hmda_institution_type_id: str | None = None
sbl_institution_type_ids: List[str] = []
sbl_institution_types: List[SblTypeAssociationDto | str] = []
hq_address_street_1: str
hq_address_street_2: str | None = None
hq_address_city: str
Expand Down Expand Up @@ -81,6 +103,14 @@ class Config:
from_attributes = True


class SblTypeAssociationDetailsDto(BaseModel):
sbl_type: InstitutionTypeDto
details: str | None = None

class Config:
from_attributes = True


class AddressStateBase(BaseModel):
code: str

Expand All @@ -95,7 +125,7 @@ class Config:
class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto):
primary_federal_regulator: FederalRegulatorDto | None = None
hmda_institution_type: InstitutionTypeDto | None = None
sbl_institution_types: List[InstitutionTypeDto] = []
sbl_institution_types: List[SblTypeAssociationDetailsDto] = []
hq_address_state: AddressStateDto
domains: List[FinancialInsitutionDomainDto] = []

Expand Down
18 changes: 9 additions & 9 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DeniedDomainDao,
AddressStateDao,
FederalRegulatorDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -72,15 +73,14 @@ async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto)
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)

# Populate with model objects from SBLInstitutionTypeDao and clear out
# the id field since it's just a view
if "sbl_institution_type_ids" in fi_data:
sbl_type_stmt = select(SBLInstitutionTypeDao).filter(
SBLInstitutionTypeDao.id.in_(fi_data["sbl_institution_type_ids"])
)
sbl_types = await session.scalars(sbl_type_stmt)
fi_data["sbl_institution_types"] = sbl_types.all()
del fi_data["sbl_institution_type_ids"]
if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details)
for t in fi.sbl_institution_types
]
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
await session.flush([db_fi])
Expand Down
3 changes: 2 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AddressStateDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -70,7 +71,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock:
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down
29 changes: 25 additions & 4 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AddressStateDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
)


Expand Down Expand Up @@ -47,7 +48,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas
primary_federal_regulator=FederalRegulatorDao(id="FRI2", name="FRI2"),
hmda_institution_type_id="HIT2",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT2", name="HIT2"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT2", name="SIT2")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT2", name="SIT2"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down Expand Up @@ -140,6 +141,26 @@ def test_create_institution_missing_required_field(
)
assert res.status_code == 422

def test_create_institution_missing_sbl_type_free_form(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
client = TestClient(app_fixture)
res = client.post(
"/v1/institutions/",
json={
"name": "testName",
"lei": "testLei",
"is_active": True,
"hq_address_street_1": "Test Address Street 1",
"hq_address_city": "Test City 1",
"hq_address_state_code": "VA",
"hq_address_zip": "00000",
"sbl_institution_types": [{"id": "13"}],
},
)
assert res.status_code == 422
assert "requires additional details." in res.json()["detail"][0]["msg"]

def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
claims = {
"name": "test",
Expand Down Expand Up @@ -197,7 +218,7 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand Down Expand Up @@ -294,7 +315,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -319,7 +340,7 @@ def test_get_associated_institutions(
primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"),
hmda_institution_type_id="HIT1",
hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"),
sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")],
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand Down
24 changes: 13 additions & 11 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
FederalRegulatorDao,
HMDAInstitutionTypeDao,
SBLInstitutionTypeDao,
SblTypeMappingDao,
SblTypeAssociationDto,
)
from entities.repos import institutions_repo as repo

Expand All @@ -39,9 +41,9 @@ async def setup(
HMDAInstitutionTypeDao(id="HIT3", name="Test HMDA Instituion ID 3"),
)
sbl_it_dao_sit1, sbl_it_dao_sit2, sbl_it_dao_sit3 = (
SBLInstitutionTypeDao(id="SIT1", name="Test SBL Instituion ID 1"),
SBLInstitutionTypeDao(id="SIT2", name="Test SBL Instituion ID 2"),
SBLInstitutionTypeDao(id="SIT3", name="Test SBL Instituion ID 3"),
SBLInstitutionTypeDao(id="1", name="Test SBL Instituion ID 1"),
SBLInstitutionTypeDao(id="2", name="Test SBL Instituion ID 2"),
SBLInstitutionTypeDao(id="13", name="Test SBL Instituion ID Other"),
)
fi_dao_123, fi_dao_456, fi_dao_sub_456 = (
FinancialInstitutionDao(
Expand All @@ -53,7 +55,7 @@ async def setup(
rssd_id=1234,
primary_federal_regulator_id="FRI1",
hmda_institution_type_id="HIT1",
sbl_institution_types=[sbl_it_dao_sit1],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
Expand All @@ -75,7 +77,7 @@ async def setup(
rssd_id=4321,
primary_federal_regulator_id="FRI2",
hmda_institution_type_id="HIT2",
sbl_institution_types=[sbl_it_dao_sit2],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
Expand All @@ -97,7 +99,7 @@ async def setup(
rssd_id=2134,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_types=[sbl_it_dao_sit3],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand Down Expand Up @@ -134,7 +136,7 @@ async def setup(
await transaction_session.commit()

async def test_get_sbl_types(self, query_session: AsyncSession):
expected_ids = {"SIT1", "SIT2", "SIT3"}
expected_ids = {"1", "2", "13"}
res = await repo.get_sbl_types(query_session)
assert len(res) == 3
assert set([r.id for r in res]) == expected_ids
Expand Down Expand Up @@ -199,7 +201,7 @@ async def test_add_institution(self, transaction_session: AsyncSession):
rssd_id=6543,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_type_ids=["SIT3"],
sbl_institution_types=[SblTypeAssociationDto(id="1")],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
Expand All @@ -218,7 +220,7 @@ async def test_add_institution(self, transaction_session: AsyncSession):
assert len(res) == 4
new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK123"])).sbl_institution_types
assert len(new_sbl_types) == 1
assert next(iter(new_sbl_types)).name == "Test SBL Instituion ID 3"
assert next(iter(new_sbl_types)).sbl_type.name == "Test SBL Instituion ID 1"

async def test_add_institution_only_required_fields(
self, transaction_session: AsyncSession, query_session: AsyncSession
Expand Down Expand Up @@ -313,8 +315,8 @@ async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncS

async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK123"])
assert res[0].sbl_institution_types[0].name == "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].sbl_type.name == "Test SBL Instituion ID 1"

async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, leis=["TESTBANK456"])
assert res[0].sbl_institution_types[0].name != "Test SBL Instituion ID 1"
assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1"
Loading
Loading