From 971e6ba6be6a1ad1efb637d29d4ddccc55663e4c Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:42:29 -0500 Subject: [PATCH] fix: fixed tests, and linters --- .../329c70502325_240131_fi_history_table.py | 8 +++--- .../3f893e52d05c_240130_add_version.py | 6 ++--- ...40131_fi_type_association_history_table.py | 9 ++++--- src/entities/models/dao.py | 3 ++- src/entities/repos/institutions_repo.py | 4 ++- tests/api/conftest.py | 4 ++- .../entities/repos/test_institutions_repo.py | 27 ++++++++++++++----- 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/db_revisions/versions/329c70502325_240131_fi_history_table.py b/db_revisions/versions/329c70502325_240131_fi_history_table.py index e4f14a9..db9faad 100644 --- a/db_revisions/versions/329c70502325_240131_fi_history_table.py +++ b/db_revisions/versions/329c70502325_240131_fi_history_table.py @@ -12,8 +12,8 @@ # revision identifiers, used by Alembic. -revision: str = '329c70502325' -down_revision: Union[str, None] = '3f893e52d05c' +revision: str = "329c70502325" +down_revision: Union[str, None] = "3f893e52d05c" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -43,9 +43,9 @@ def upgrade() -> None: sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column("modified_by", sa.String()), sa.Column("changeset", sa.JSON), - sa.PrimaryKeyConstraint("lei", "version") + sa.PrimaryKeyConstraint("lei", "version"), ) def downgrade() -> None: - op.drop_table("financial_institutions_history") \ No newline at end of file + op.drop_table("financial_institutions_history") diff --git a/db_revisions/versions/3f893e52d05c_240130_add_version.py b/db_revisions/versions/3f893e52d05c_240130_add_version.py index 88561e9..fbe7521 100644 --- a/db_revisions/versions/3f893e52d05c_240130_add_version.py +++ b/db_revisions/versions/3f893e52d05c_240130_add_version.py @@ -12,8 +12,8 @@ # revision identifiers, used by Alembic. -revision: str = '3f893e52d05c' -down_revision: Union[str, None] = '6826f05140cd' +revision: str = "3f893e52d05c" +down_revision: Union[str, None] = "6826f05140cd" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -31,4 +31,4 @@ def downgrade() -> None: op.drop_column("financial_institutions", "version") op.drop_column("financial_institutions", "modified_by") op.drop_column("fi_to_type_mapping", "version") - op.drop_column("fi_to_type_mapping", "modified_by") \ No newline at end of file + op.drop_column("fi_to_type_mapping", "modified_by") diff --git a/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py index 6be4c99..a3753b7 100644 --- a/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py +++ b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py @@ -12,11 +12,12 @@ # revision identifiers, used by Alembic. -revision: str = '8106d83ff594' -down_revision: Union[str, None] = '329c70502325' +revision: str = "8106d83ff594" +down_revision: Union[str, None] = "329c70502325" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None + def upgrade() -> None: op.create_table( "fi_to_type_mapping_history", @@ -27,9 +28,9 @@ def upgrade() -> None: sa.Column("modified_by", sa.String()), sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column("changeset", sa.JSON), - sa.PrimaryKeyConstraint("fi_id", "type_id", "version") + sa.PrimaryKeyConstraint("fi_id", "type_id", "version"), ) def downgrade() -> None: - op.drop_table("fi_to_type_mapping_history") \ No newline at end of file + op.drop_table("fi_to_type_mapping_history") diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 5c080ce..464a825 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -20,8 +20,9 @@ class SblTypeMappingDao(Base): lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True) type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True) sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin") - details: Mapped[str] = mapped_column() + details: Mapped[str] = mapped_column(nullable=True) modified_by: Mapped[str] = mapped_column() + def as_db_dict(self): data = {} for attr, column in inspect(self.__class__).c.items(): diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index c14e95b..48b3e36 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -71,7 +71,9 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator return await query_type(session, FederalRegulatorDao) -async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser) -> FinancialInstitutionDao: +async def upsert_institution( + session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser +) -> FinancialInstitutionDao: async with session.begin(): fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 0291c51..6de83c8 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -37,6 +37,7 @@ def auth_mock(mocker: MockerFixture) -> Mock: @pytest.fixture def authed_user_mock(auth_mock: Mock) -> Mock: claims = { + "id": "test_user_id", "name": "test", "preferred_username": "test_user", "email": "test@local.host", @@ -56,7 +57,7 @@ def unauthed_user_mock(auth_mock: Mock) -> Mock: @pytest.fixture -def get_institutions_mock(mocker: MockerFixture) -> Mock: +def get_institutions_mock(mocker: MockerFixture, authed_user_mock: Mock) -> Mock: mock = mocker.patch("entities.repos.institutions_repo.get_institutions") mock.return_value = [ FinancialInstitutionDao( @@ -83,6 +84,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock: top_holder_lei="TOPHOLDERLEI123", top_holder_legal_name="TOP HOLDER LEI 123", top_holder_rssd_id=123456, + modified_by="test_user_id", ) ] return mock diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index ce18766..f764b45 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -17,9 +17,12 @@ SblTypeAssociationDto, ) from entities.repos import institutions_repo as repo +from regtech_api_commons.models import AuthenticatedUser class TestInstitutionsRepo: + auth_user: AuthenticatedUser = AuthenticatedUser.from_claim({"id": "test_user_id"}) + @pytest.fixture(scope="function", autouse=True) async def setup( self, @@ -55,7 +58,7 @@ async def setup( rssd_id=1234, primary_federal_regulator_id="FRI1", hmda_institution_type_id="HIT1", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1, modified_by="test_user_id")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -67,6 +70,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI123", top_holder_legal_name="TOP HOLDER LEI 123", top_holder_rssd_id=123456, + modified_by="test_user_id", ), FinancialInstitutionDao( name="Test Bank 456", @@ -77,7 +81,7 @@ async def setup( rssd_id=4321, primary_federal_regulator_id="FRI2", hmda_institution_type_id="HIT2", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2, modified_by="test_user_id")], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", @@ -89,6 +93,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI456", top_holder_legal_name="TOP HOLDER LEI 456", top_holder_rssd_id=654321, + modified_by="test_user_id", ), FinancialInstitutionDao( name="Test Sub Bank 456", @@ -99,7 +104,9 @@ async def setup( rssd_id=2134, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")], + sbl_institution_types=[ + SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, modified_by="test_user_id", details="test") + ], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -111,6 +118,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI456", top_holder_legal_name="TOP HOLDER LEI SUB BANK 456", top_holder_rssd_id=321654, + modified_by="test_user_id", ), ) @@ -213,7 +221,9 @@ async def test_add_institution(self, transaction_session: AsyncSession): top_holder_lei="TOPHOLDERNEWBANKLEI123", top_holder_legal_name="TOP HOLDER NEW BANK LEI 123", top_holder_rssd_id=876543, + modified_by="test_user_id", ), + self.auth_user, ) assert db_fi.domains == [] res = await repo.get_institutions(transaction_session) @@ -227,7 +237,7 @@ async def test_add_institution_only_required_fields( ): await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Minimal Bank 123", lei="MINBANK123", is_active=True, @@ -236,6 +246,7 @@ async def test_add_institution_only_required_fields( hq_address_state_code="FL", hq_address_zip="22222", ), + self.auth_user, ) res = await repo.get_institution(query_session, "MINBANK123") assert res is not None @@ -247,19 +258,20 @@ async def test_add_institution_missing_required_fields( with pytest.raises(Exception) as e: await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Minimal Bank 123", lei="MINBANK123", ), + self.auth_user, ) - assert "not null constraint failed" in str(e.value).lower() + assert "field required" in str(e.value).lower() res = await repo.get_institution(query_session, "MINBANK123") assert res is None async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Test Bank 234", lei="TESTBANK123", is_active=True, @@ -268,6 +280,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): hq_address_state_code="GA", hq_address_zip="00000", ), + self.auth_user, ) res = await repo.get_institutions(transaction_session) assert len(res) == 3