diff --git a/db_revisions/versions/6826f05140cd_240111_add_free_form_text_for_sbl_types.py b/db_revisions/versions/6826f05140cd_240111_add_free_form_text_for_sbl_types.py new file mode 100644 index 0000000..6f4264f --- /dev/null +++ b/db_revisions/versions/6826f05140cd_240111_add_free_form_text_for_sbl_types.py @@ -0,0 +1,26 @@ +"""240111 add free form text for sbl types + +Revision ID: 6826f05140cd +Revises: ada681e1877f +Create Date: 2024-01-11 14:33:56.518611 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "6826f05140cd" +down_revision: Union[str, None] = "ada681e1877f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("fi_to_type_mapping", sa.Column("details", type_=sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("fi_to_type_mapping", "details") diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index 0394759..64feb9b 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -18,6 +18,9 @@ "FederalRegulatorDto", "InstitutionTypeDto", "AddressStateDto", + "SblTypeMappingDao", + "SblTypeAssociationDto", + "SblTypeAssociationDetailsDto", ] from .dao import ( @@ -29,6 +32,7 @@ HMDAInstitutionTypeDao, SBLInstitutionTypeDao, AddressStateDao, + SblTypeMappingDao, ) from .dto import ( FinancialInstitutionDto, @@ -42,4 +46,6 @@ FederalRegulatorDto, InstitutionTypeDto, AddressStateDto, + SblTypeAssociationDto, + SblTypeAssociationDetailsDto, ) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 3cf3bb7..031e687 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -1,10 +1,9 @@ from datetime import datetime from typing import List -from sqlalchemy import ForeignKey, func, String, Table, Column +from sqlalchemy import ForeignKey, func, String from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlalchemy.ext.associationproxy import association_proxy, AssociationProxy class Base(AsyncAttrs, DeclarativeBase): @@ -15,12 +14,12 @@ class AuditMixin(object): event_time: Mapped[datetime] = mapped_column(server_default=func.now()) -fi_to_type_mapping = Table( - "fi_to_type_mapping", - Base.metadata, - Column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True), - Column("type_id", ForeignKey("sbl_institution_type.id"), primary_key=True), -) +class SblTypeMappingDao(Base): + __tablename__ = "fi_to_type_mapping" + lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True) + type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True) + sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin") + details: Mapped[str] = mapped_column(nullable=True) class FinancialInstitutionDao(AuditMixin, Base): @@ -37,10 +36,7 @@ class FinancialInstitutionDao(AuditMixin, Base): primary_federal_regulator: Mapped["FederalRegulatorDao"] = relationship(lazy="selectin") hmda_institution_type_id: Mapped[str] = mapped_column(ForeignKey("hmda_institution_type.id"), nullable=True) hmda_institution_type: Mapped["HMDAInstitutionTypeDao"] = relationship(lazy="selectin") - sbl_institution_types: Mapped[List["SBLInstitutionTypeDao"]] = relationship( - lazy="selectin", secondary=fi_to_type_mapping - ) - sbl_institution_type_ids: AssociationProxy[List[str]] = association_proxy("sbl_institution_types", "id") + sbl_institution_types: Mapped[List[SblTypeMappingDao]] = relationship(lazy="selectin", cascade="all, delete-orphan") hq_address_street_1: Mapped[str] hq_address_street_2: Mapped[str] = mapped_column(nullable=True) hq_address_city: Mapped[str] diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 275cbef..04ac4dc 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -1,5 +1,5 @@ from typing import List, Dict, Any, Set -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from starlette.authentication import BaseUser @@ -24,12 +24,34 @@ class FinancialInstitutionBase(BaseModel): is_active: bool +class SblTypeAssociationDto(BaseModel): + id: str + details: str | None = None + + @model_validator(mode="after") + def validate_type(self) -> "SblTypeAssociationDto": + """ + Validates `Other` type and free form input. + If `Other` is selected, then details should be filled in; + vice versa if `Other` is not selected, then details should be null. + """ + other_type_id = "13" + if self.id == other_type_id and not self.details: + raise ValueError(f"SBL institution type '{other_type_id}' requires additional details.") + elif self.id != other_type_id: + self.details = None + return self + + class Config: + from_attributes = True + + class FinancialInstitutionDto(FinancialInstitutionBase): tax_id: str | None = None rssd_id: int | None = None primary_federal_regulator_id: str | None = None hmda_institution_type_id: str | None = None - sbl_institution_type_ids: List[str] = [] + sbl_institution_types: List[SblTypeAssociationDto | str] = [] hq_address_street_1: str hq_address_street_2: str | None = None hq_address_city: str @@ -81,6 +103,14 @@ class Config: from_attributes = True +class SblTypeAssociationDetailsDto(BaseModel): + sbl_type: InstitutionTypeDto + details: str | None = None + + class Config: + from_attributes = True + + class AddressStateBase(BaseModel): code: str @@ -95,7 +125,7 @@ class Config: class FinancialInstitutionWithRelationsDto(FinancialInstitutionDto): primary_federal_regulator: FederalRegulatorDto | None = None hmda_institution_type: InstitutionTypeDto | None = None - sbl_institution_types: List[InstitutionTypeDto] = [] + sbl_institution_types: List[SblTypeAssociationDetailsDto] = [] hq_address_state: AddressStateDto domains: List[FinancialInsitutionDomainDto] = [] diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 0b2e212..b9b6dab 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -16,6 +16,7 @@ DeniedDomainDao, AddressStateDao, FederalRegulatorDao, + SblTypeMappingDao, ) @@ -72,15 +73,14 @@ async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) - # Populate with model objects from SBLInstitutionTypeDao and clear out - # the id field since it's just a view - if "sbl_institution_type_ids" in fi_data: - sbl_type_stmt = select(SBLInstitutionTypeDao).filter( - SBLInstitutionTypeDao.id.in_(fi_data["sbl_institution_type_ids"]) - ) - sbl_types = await session.scalars(sbl_type_stmt) - fi_data["sbl_institution_types"] = sbl_types.all() - del fi_data["sbl_institution_type_ids"] + if "sbl_institution_types" in fi_data: + types_association = [ + SblTypeMappingDao(type_id=t) + if isinstance(t, str) + else SblTypeMappingDao(type_id=t.id, details=t.details) + for t in fi.sbl_institution_types + ] + fi_data["sbl_institution_types"] = types_association db_fi = await session.merge(FinancialInstitutionDao(**fi_data)) return await session.get(FinancialInstitutionDao, db_fi.lei) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 80a525d..19f4b14 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -13,6 +13,7 @@ AddressStateDao, HMDAInstitutionTypeDao, SBLInstitutionTypeDao, + SblTypeMappingDao, ) @@ -70,7 +71,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock: primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 9af99dd..e9dd030 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -12,6 +12,7 @@ AddressStateDao, HMDAInstitutionTypeDao, SBLInstitutionTypeDao, + SblTypeMappingDao, ) @@ -47,7 +48,7 @@ def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: Fas primary_federal_regulator=FederalRegulatorDao(id="FRI2", name="FRI2"), hmda_institution_type_id="HIT2", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT2", name="HIT2"), - sbl_institution_types=[SBLInstitutionTypeDao(id="SIT2", name="SIT2")], + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT2", name="SIT2"))], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -140,6 +141,26 @@ def test_create_institution_missing_required_field( ) assert res.status_code == 422 + def test_create_institution_missing_sbl_type_free_form( + self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + ): + client = TestClient(app_fixture) + res = client.post( + "/v1/institutions/", + json={ + "name": "testName", + "lei": "testLei", + "is_active": True, + "hq_address_street_1": "Test Address Street 1", + "hq_address_city": "Test City 1", + "hq_address_state_code": "VA", + "hq_address_zip": "00000", + "sbl_institution_types": [{"id": "13"}], + }, + ) + assert res.status_code == 422 + assert "requires additional details." in res.json()["detail"][0]["msg"] + def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", @@ -197,7 +218,7 @@ def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAP primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -294,7 +315,7 @@ def test_get_associated_institutions( primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -319,7 +340,7 @@ def test_get_associated_institutions( primary_federal_regulator=FederalRegulatorDao(id="FRI1", name="FRI1"), hmda_institution_type_id="HIT1", hmda_institution_type=HMDAInstitutionTypeDao(id="HIT1", name="HIT1"), - sbl_institution_types=[SBLInstitutionTypeDao(id="SIT1", name="SIT1")], + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 29cc3f0..ce18766 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -13,6 +13,8 @@ FederalRegulatorDao, HMDAInstitutionTypeDao, SBLInstitutionTypeDao, + SblTypeMappingDao, + SblTypeAssociationDto, ) from entities.repos import institutions_repo as repo @@ -39,9 +41,9 @@ async def setup( HMDAInstitutionTypeDao(id="HIT3", name="Test HMDA Instituion ID 3"), ) sbl_it_dao_sit1, sbl_it_dao_sit2, sbl_it_dao_sit3 = ( - SBLInstitutionTypeDao(id="SIT1", name="Test SBL Instituion ID 1"), - SBLInstitutionTypeDao(id="SIT2", name="Test SBL Instituion ID 2"), - SBLInstitutionTypeDao(id="SIT3", name="Test SBL Instituion ID 3"), + SBLInstitutionTypeDao(id="1", name="Test SBL Instituion ID 1"), + SBLInstitutionTypeDao(id="2", name="Test SBL Instituion ID 2"), + SBLInstitutionTypeDao(id="13", name="Test SBL Instituion ID Other"), ) fi_dao_123, fi_dao_456, fi_dao_sub_456 = ( FinancialInstitutionDao( @@ -53,7 +55,7 @@ async def setup( rssd_id=1234, primary_federal_regulator_id="FRI1", hmda_institution_type_id="HIT1", - sbl_institution_types=[sbl_it_dao_sit1], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -75,7 +77,7 @@ async def setup( rssd_id=4321, primary_federal_regulator_id="FRI2", hmda_institution_type_id="HIT2", - sbl_institution_types=[sbl_it_dao_sit2], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", @@ -97,7 +99,7 @@ async def setup( rssd_id=2134, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_types=[sbl_it_dao_sit3], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -134,7 +136,7 @@ async def setup( await transaction_session.commit() async def test_get_sbl_types(self, query_session: AsyncSession): - expected_ids = {"SIT1", "SIT2", "SIT3"} + expected_ids = {"1", "2", "13"} res = await repo.get_sbl_types(query_session) assert len(res) == 3 assert set([r.id for r in res]) == expected_ids @@ -199,7 +201,7 @@ async def test_add_institution(self, transaction_session: AsyncSession): rssd_id=6543, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_type_ids=["SIT3"], + sbl_institution_types=[SblTypeAssociationDto(id="1")], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -218,7 +220,7 @@ async def test_add_institution(self, transaction_session: AsyncSession): assert len(res) == 4 new_sbl_types = next(iter([fi for fi in res if fi.lei == "NEWBANK123"])).sbl_institution_types assert len(new_sbl_types) == 1 - assert next(iter(new_sbl_types)).name == "Test SBL Instituion ID 3" + assert next(iter(new_sbl_types)).sbl_type.name == "Test SBL Instituion ID 1" async def test_add_institution_only_required_fields( self, transaction_session: AsyncSession, query_session: AsyncSession @@ -313,8 +315,8 @@ async def test_institution_mapped_to_hmda_it_invalid(self, query_session: AsyncS async def test_institution_mapped_to_sbl_it_valid(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["TESTBANK123"]) - assert res[0].sbl_institution_types[0].name == "Test SBL Instituion ID 1" + assert res[0].sbl_institution_types[0].sbl_type.name == "Test SBL Instituion ID 1" async def test_institution_mapped_to_sbl_it_invalid(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["TESTBANK456"]) - assert res[0].sbl_institution_types[0].name != "Test SBL Instituion ID 1" + assert res[0].sbl_institution_types[0].sbl_type.name != "Test SBL Instituion ID 1" diff --git a/tests/migrations/test_schema.py b/tests/migrations/test_schema.py index 09b6ac4..88273ba 100644 --- a/tests/migrations/test_schema.py +++ b/tests/migrations/test_schema.py @@ -10,7 +10,7 @@ def test_financial_institutions_schema_migrate_up_to_045aa502e050( alembic_runner.migrate_up_to("045aa502e050") inspector = sqlalchemy.inspect(alembic_engine) - expexted_columns = [ + expected_columns = [ "lei", "name", "event_time", @@ -35,7 +35,7 @@ def test_financial_institutions_schema_migrate_up_to_045aa502e050( columns = inspector.get_columns("financial_institutions") columns_names = [column.get("name") for column in columns] - assert columns_names == expexted_columns + assert columns_names == expected_columns def test_financial_institutions_schema_migrate_up_to_20e0d51d8be9( @@ -44,7 +44,7 @@ def test_financial_institutions_schema_migrate_up_to_20e0d51d8be9( alembic_runner.migrate_up_to("20e0d51d8be9") inspector = sqlalchemy.inspect(alembic_engine) - expexted_columns = [ + expected_columns = [ "lei", "name", "event_time", @@ -53,4 +53,14 @@ def test_financial_institutions_schema_migrate_up_to_20e0d51d8be9( columns = inspector.get_columns("financial_institutions") columns_names = [column.get("name") for column in columns] - assert columns_names == expexted_columns + assert columns_names == expected_columns + + +def test_fi_types_table_6826f05140cd(alembic_runner: MigrationContext, alembic_engine: Engine): + alembic_runner.migrate_up_to("6826f05140cd") + inspector = sqlalchemy.inspect(alembic_engine) + expected_columns = ["fi_id", "type_id", "details"] + columns = inspector.get_columns("fi_to_type_mapping") + columns_names = [column.get("name") for column in columns] + + assert columns_names == expected_columns