From 3e0320a2c6965dc47cb6f9f8dbc422370b4e0055 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Fri, 2 Feb 2024 18:11:54 -0500 Subject: [PATCH 1/5] feat: add fi tables versioning --- .../329c70502325_240131_fi_history_table.py | 51 +++++++++++++++++++ .../3f893e52d05c_240130_add_version.py | 34 +++++++++++++ ...40131_fi_type_association_history_table.py | 35 +++++++++++++ src/entities/listeners.py | 51 +++++++++++++++++++ src/entities/models/dao.py | 18 +++++-- src/entities/repos/institutions_repo.py | 11 ++-- src/main.py | 2 + src/routers/institutions.py | 2 +- 8 files changed, 195 insertions(+), 9 deletions(-) create mode 100644 db_revisions/versions/329c70502325_240131_fi_history_table.py create mode 100644 db_revisions/versions/3f893e52d05c_240130_add_version.py create mode 100644 db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py create mode 100644 src/entities/listeners.py diff --git a/db_revisions/versions/329c70502325_240131_fi_history_table.py b/db_revisions/versions/329c70502325_240131_fi_history_table.py new file mode 100644 index 0000000..e4f14a9 --- /dev/null +++ b/db_revisions/versions/329c70502325_240131_fi_history_table.py @@ -0,0 +1,51 @@ +"""240131 fi history table + +Revision ID: 329c70502325 +Revises: 3f893e52d05c +Create Date: 2024-01-31 10:23:01.081439 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '329c70502325' +down_revision: Union[str, None] = '3f893e52d05c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "financial_institutions_history", + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("lei", sa.String(), nullable=False), + sa.Column("name", sa.String()), + sa.Column("is_active", sa.Boolean()), + sa.Column("tax_id", sa.String(9)), + sa.Column("rssd_id", sa.Integer()), + sa.Column("primary_federal_regulator_id", sa.String(4)), + sa.Column("hmda_institution_type_id", sa.String()), + sa.Column("hq_address_street_1", sa.String()), + sa.Column("hq_address_street_2", sa.String()), + sa.Column("hq_address_city", sa.String()), + sa.Column("hq_address_state_code", sa.String(2)), + sa.Column("hq_address_zip", sa.String(5)), + sa.Column("parent_lei", sa.String(20)), + sa.Column("parent_legal_name", sa.String()), + sa.Column("parent_rssd_id", sa.Integer()), + sa.Column("top_holder_lei", sa.String(20)), + sa.Column("top_holder_legal_name", sa.String()), + sa.Column("top_holder_rssd_id", sa.Integer()), + sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.Column("modified_by", sa.String()), + sa.Column("changeset", sa.JSON), + sa.PrimaryKeyConstraint("lei", "version") + ) + + +def downgrade() -> None: + op.drop_table("financial_institutions_history") \ No newline at end of file diff --git a/db_revisions/versions/3f893e52d05c_240130_add_version.py b/db_revisions/versions/3f893e52d05c_240130_add_version.py new file mode 100644 index 0000000..88561e9 --- /dev/null +++ b/db_revisions/versions/3f893e52d05c_240130_add_version.py @@ -0,0 +1,34 @@ +"""240130 add version + +Revision ID: 3f893e52d05c +Revises: 6826f05140cd +Create Date: 2024-01-30 14:37:47.652233 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '3f893e52d05c' +down_revision: Union[str, None] = '6826f05140cd' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + with op.batch_alter_table("financial_institutions") as batch_op: + batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1"))) + batch_op.add_column(sa.Column("modified_by", sa.String())) + with op.batch_alter_table("fi_to_type_mapping") as batch_op: + batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1"))) + batch_op.add_column(sa.Column("modified_by", sa.String())) + + +def downgrade() -> None: + op.drop_column("financial_institutions", "version") + op.drop_column("financial_institutions", "modified_by") + op.drop_column("fi_to_type_mapping", "version") + op.drop_column("fi_to_type_mapping", "modified_by") \ No newline at end of file diff --git a/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py new file mode 100644 index 0000000..6be4c99 --- /dev/null +++ b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py @@ -0,0 +1,35 @@ +"""240131 fi type association history table + +Revision ID: 8106d83ff594 +Revises: 329c70502325 +Create Date: 2024-01-31 10:23:21.163572 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '8106d83ff594' +down_revision: Union[str, None] = '329c70502325' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +def upgrade() -> None: + op.create_table( + "fi_to_type_mapping_history", + sa.Column("version", sa.Integer()), + sa.Column("fi_id", sa.String(), nullable=False), + sa.Column("type_id", sa.String(), nullable=False), + sa.Column("details", sa.String()), + sa.Column("modified_by", sa.String()), + sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.Column("changeset", sa.JSON), + sa.PrimaryKeyConstraint("fi_id", "type_id", "version") + ) + + +def downgrade() -> None: + op.drop_table("fi_to_type_mapping_history") \ No newline at end of file diff --git a/src/entities/listeners.py b/src/entities/listeners.py new file mode 100644 index 0000000..c7d3b25 --- /dev/null +++ b/src/entities/listeners.py @@ -0,0 +1,51 @@ +from sqlalchemy import Connection, Table, event, inspect +from sqlalchemy.orm import Mapper + +from .models.dao import Base, FinancialInstitutionDao +from entities.engine.engine import engine + + +async def setup_fi_dao_listeners(): + async with engine.begin() as connection: + fi_history, mapping_history = await connection.run_sync( + lambda conn: ( + Table("financial_institutions_history", Base.metadata, autoload_with=conn), + Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn), + ) + ) + + def insert_history( + mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao + ): + new_version = target.version + 1 if target.version else 1 + changes = {} + state = inspect(target) + for attr in state.attrs: + if attr.key == "event_time": + continue + attr_hist = attr.load_history() + if not attr_hist.has_changes(): + continue + if attr.key == "sbl_institution_types": + old_types = [o.as_dict() for o in attr_hist.deleted] + new_types = [{**n.as_dict(), "version": new_version} for n in attr_hist.added] + changes[attr.key] = {"old": old_types, "new": new_types} + else: + changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added} + if changes: + target.version = new_version + for t in target.sbl_institution_types: + t.version = new_version + hist = target.__dict__.copy() + hist.pop("event_time", None) + history_columns = fi_history.columns.keys() + for key in hist.copy(): + if key not in history_columns: + del hist[key] + hist["changeset"] = changes + types = [t.as_db_dict() for t in target.sbl_institution_types] + connection.execute(fi_history.insert().values(hist)) + connection.execute(mapping_history.insert().values(types)) + + event.listen(FinancialInstitutionDao, "before_insert", insert_history) + event.listen(FinancialInstitutionDao, "before_update", insert_history) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 031e687..5c080ce 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -1,8 +1,7 @@ from datetime import datetime from typing import List -from sqlalchemy import ForeignKey, func, String -from sqlalchemy.orm import Mapped, mapped_column, relationship -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import ForeignKey, func, String, inspect +from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase from sqlalchemy.ext.asyncio import AsyncAttrs @@ -16,14 +15,24 @@ class AuditMixin(object): class SblTypeMappingDao(Base): __tablename__ = "fi_to_type_mapping" + version: Mapped[int] = mapped_column(nullable=False, default=0) + __mapper_args__ = {"version_id_col": version, "version_id_generator": False} lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True) type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True) sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin") - details: Mapped[str] = mapped_column(nullable=True) + details: Mapped[str] = mapped_column() + modified_by: Mapped[str] = mapped_column() + def as_db_dict(self): + data = {} + for attr, column in inspect(self.__class__).c.items(): + data[column.name] = getattr(self, attr) + return data class FinancialInstitutionDao(AuditMixin, Base): __tablename__ = "financial_institutions" + version: Mapped[int] = mapped_column(nullable=False, default=0) + __mapper_args__ = {"version_id_col": version, "version_id_generator": False} lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True) name: Mapped[str] = mapped_column(index=True) is_active: Mapped[bool] = mapped_column(index=True) @@ -49,6 +58,7 @@ class FinancialInstitutionDao(AuditMixin, Base): top_holder_lei: Mapped[str] = mapped_column(String(20), nullable=True) top_holder_legal_name: Mapped[str] = mapped_column(nullable=True) top_holder_rssd_id: Mapped[int] = mapped_column(nullable=True) + modified_by: Mapped[str] = mapped_column() class FinancialInstitutionDomainDao(AuditMixin, Base): diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 6cde0c8..c14e95b 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -1,9 +1,12 @@ +from datetime import datetime from typing import List from sqlalchemy import select, func from sqlalchemy.orm import joinedload from sqlalchemy.ext.asyncio import AsyncSession +from regtech_api_commons.models import AuthenticatedUser + from .repo_utils import query_type from entities.models import ( @@ -68,21 +71,21 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator return await query_type(session, FederalRegulatorDao) -async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao: +async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser) -> FinancialInstitutionDao: async with session.begin(): fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) if "sbl_institution_types" in fi_data: types_association = [ - SblTypeMappingDao(type_id=t) + SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id) if isinstance(t, str) - else SblTypeMappingDao(type_id=t.id, details=t.details) + else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id) for t in fi.sbl_institution_types ] fi_data["sbl_institution_types"] = types_association - db_fi = await session.merge(FinancialInstitutionDao(**fi_data)) + db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id, event_time=datetime.now())) await session.flush() await session.refresh(db_fi) return db_fi diff --git a/src/main.py b/src/main.py index 2dffd0e..a6368a1 100644 --- a/src/main.py +++ b/src/main.py @@ -9,6 +9,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware from alembic.config import Config from alembic import command +from entities.listeners import setup_fi_dao_listeners from routers import admin_router, institutions_router @@ -33,6 +34,7 @@ async def lifespan(app_: FastAPI): log.info("Starting up...") log.info("run alembic upgrade head...") run_migrations() + await setup_fi_dao_listeners() yield log.info("Shutting down...") diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 5083683..49556bf 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -51,7 +51,7 @@ async def create_institution( request: Request, fi: FinancialInstitutionDto, ): - db_fi = await repo.upsert_institution(request.state.db_session, fi) + db_fi = await repo.upsert_institution(request.state.db_session, fi, request.user) kc_id = oauth2_admin.upsert_group(fi.lei, fi.name) return kc_id, db_fi From 4b22a3da5fd9bf7bb56fc5c689e2080909d66628 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Fri, 2 Feb 2024 18:27:01 -0500 Subject: [PATCH 2/5] fix: method name updated --- src/entities/listeners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/entities/listeners.py b/src/entities/listeners.py index c7d3b25..6b03b76 100644 --- a/src/entities/listeners.py +++ b/src/entities/listeners.py @@ -27,8 +27,8 @@ def insert_history( if not attr_hist.has_changes(): continue if attr.key == "sbl_institution_types": - old_types = [o.as_dict() for o in attr_hist.deleted] - new_types = [{**n.as_dict(), "version": new_version} for n in attr_hist.added] + old_types = [o.as_db_dict() for o in attr_hist.deleted] + new_types = [{**n.as_db_dict(), "version": new_version} for n in attr_hist.added] changes[attr.key] = {"old": old_types, "new": new_types} else: changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added} From 971e6ba6be6a1ad1efb637d29d4ddccc55663e4c Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:42:29 -0500 Subject: [PATCH 3/5] fix: fixed tests, and linters --- .../329c70502325_240131_fi_history_table.py | 8 +++--- .../3f893e52d05c_240130_add_version.py | 6 ++--- ...40131_fi_type_association_history_table.py | 9 ++++--- src/entities/models/dao.py | 3 ++- src/entities/repos/institutions_repo.py | 4 ++- tests/api/conftest.py | 4 ++- .../entities/repos/test_institutions_repo.py | 27 ++++++++++++++----- 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/db_revisions/versions/329c70502325_240131_fi_history_table.py b/db_revisions/versions/329c70502325_240131_fi_history_table.py index e4f14a9..db9faad 100644 --- a/db_revisions/versions/329c70502325_240131_fi_history_table.py +++ b/db_revisions/versions/329c70502325_240131_fi_history_table.py @@ -12,8 +12,8 @@ # revision identifiers, used by Alembic. -revision: str = '329c70502325' -down_revision: Union[str, None] = '3f893e52d05c' +revision: str = "329c70502325" +down_revision: Union[str, None] = "3f893e52d05c" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -43,9 +43,9 @@ def upgrade() -> None: sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column("modified_by", sa.String()), sa.Column("changeset", sa.JSON), - sa.PrimaryKeyConstraint("lei", "version") + sa.PrimaryKeyConstraint("lei", "version"), ) def downgrade() -> None: - op.drop_table("financial_institutions_history") \ No newline at end of file + op.drop_table("financial_institutions_history") diff --git a/db_revisions/versions/3f893e52d05c_240130_add_version.py b/db_revisions/versions/3f893e52d05c_240130_add_version.py index 88561e9..fbe7521 100644 --- a/db_revisions/versions/3f893e52d05c_240130_add_version.py +++ b/db_revisions/versions/3f893e52d05c_240130_add_version.py @@ -12,8 +12,8 @@ # revision identifiers, used by Alembic. -revision: str = '3f893e52d05c' -down_revision: Union[str, None] = '6826f05140cd' +revision: str = "3f893e52d05c" +down_revision: Union[str, None] = "6826f05140cd" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -31,4 +31,4 @@ def downgrade() -> None: op.drop_column("financial_institutions", "version") op.drop_column("financial_institutions", "modified_by") op.drop_column("fi_to_type_mapping", "version") - op.drop_column("fi_to_type_mapping", "modified_by") \ No newline at end of file + op.drop_column("fi_to_type_mapping", "modified_by") diff --git a/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py index 6be4c99..a3753b7 100644 --- a/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py +++ b/db_revisions/versions/8106d83ff594_240131_fi_type_association_history_table.py @@ -12,11 +12,12 @@ # revision identifiers, used by Alembic. -revision: str = '8106d83ff594' -down_revision: Union[str, None] = '329c70502325' +revision: str = "8106d83ff594" +down_revision: Union[str, None] = "329c70502325" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None + def upgrade() -> None: op.create_table( "fi_to_type_mapping_history", @@ -27,9 +28,9 @@ def upgrade() -> None: sa.Column("modified_by", sa.String()), sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False), sa.Column("changeset", sa.JSON), - sa.PrimaryKeyConstraint("fi_id", "type_id", "version") + sa.PrimaryKeyConstraint("fi_id", "type_id", "version"), ) def downgrade() -> None: - op.drop_table("fi_to_type_mapping_history") \ No newline at end of file + op.drop_table("fi_to_type_mapping_history") diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 5c080ce..464a825 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -20,8 +20,9 @@ class SblTypeMappingDao(Base): lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True) type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True) sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin") - details: Mapped[str] = mapped_column() + details: Mapped[str] = mapped_column(nullable=True) modified_by: Mapped[str] = mapped_column() + def as_db_dict(self): data = {} for attr, column in inspect(self.__class__).c.items(): diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index c14e95b..48b3e36 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -71,7 +71,9 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator return await query_type(session, FederalRegulatorDao) -async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser) -> FinancialInstitutionDao: +async def upsert_institution( + session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser +) -> FinancialInstitutionDao: async with session.begin(): fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 0291c51..6de83c8 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -37,6 +37,7 @@ def auth_mock(mocker: MockerFixture) -> Mock: @pytest.fixture def authed_user_mock(auth_mock: Mock) -> Mock: claims = { + "id": "test_user_id", "name": "test", "preferred_username": "test_user", "email": "test@local.host", @@ -56,7 +57,7 @@ def unauthed_user_mock(auth_mock: Mock) -> Mock: @pytest.fixture -def get_institutions_mock(mocker: MockerFixture) -> Mock: +def get_institutions_mock(mocker: MockerFixture, authed_user_mock: Mock) -> Mock: mock = mocker.patch("entities.repos.institutions_repo.get_institutions") mock.return_value = [ FinancialInstitutionDao( @@ -83,6 +84,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock: top_holder_lei="TOPHOLDERLEI123", top_holder_legal_name="TOP HOLDER LEI 123", top_holder_rssd_id=123456, + modified_by="test_user_id", ) ] return mock diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index ce18766..f764b45 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -17,9 +17,12 @@ SblTypeAssociationDto, ) from entities.repos import institutions_repo as repo +from regtech_api_commons.models import AuthenticatedUser class TestInstitutionsRepo: + auth_user: AuthenticatedUser = AuthenticatedUser.from_claim({"id": "test_user_id"}) + @pytest.fixture(scope="function", autouse=True) async def setup( self, @@ -55,7 +58,7 @@ async def setup( rssd_id=1234, primary_federal_regulator_id="FRI1", hmda_institution_type_id="HIT1", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1, modified_by="test_user_id")], hq_address_street_1="Test Address Street 1", hq_address_street_2="", hq_address_city="Test City 1", @@ -67,6 +70,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI123", top_holder_legal_name="TOP HOLDER LEI 123", top_holder_rssd_id=123456, + modified_by="test_user_id", ), FinancialInstitutionDao( name="Test Bank 456", @@ -77,7 +81,7 @@ async def setup( rssd_id=4321, primary_federal_regulator_id="FRI2", hmda_institution_type_id="HIT2", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)], + sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2, modified_by="test_user_id")], hq_address_street_1="Test Address Street 2", hq_address_street_2="", hq_address_city="Test City 2", @@ -89,6 +93,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI456", top_holder_legal_name="TOP HOLDER LEI 456", top_holder_rssd_id=654321, + modified_by="test_user_id", ), FinancialInstitutionDao( name="Test Sub Bank 456", @@ -99,7 +104,9 @@ async def setup( rssd_id=2134, primary_federal_regulator_id="FRI3", hmda_institution_type_id="HIT3", - sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")], + sbl_institution_types=[ + SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, modified_by="test_user_id", details="test") + ], hq_address_street_1="Test Address Street 3", hq_address_street_2="", hq_address_city="Test City 3", @@ -111,6 +118,7 @@ async def setup( top_holder_lei="TOPHOLDERLEI456", top_holder_legal_name="TOP HOLDER LEI SUB BANK 456", top_holder_rssd_id=321654, + modified_by="test_user_id", ), ) @@ -213,7 +221,9 @@ async def test_add_institution(self, transaction_session: AsyncSession): top_holder_lei="TOPHOLDERNEWBANKLEI123", top_holder_legal_name="TOP HOLDER NEW BANK LEI 123", top_holder_rssd_id=876543, + modified_by="test_user_id", ), + self.auth_user, ) assert db_fi.domains == [] res = await repo.get_institutions(transaction_session) @@ -227,7 +237,7 @@ async def test_add_institution_only_required_fields( ): await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Minimal Bank 123", lei="MINBANK123", is_active=True, @@ -236,6 +246,7 @@ async def test_add_institution_only_required_fields( hq_address_state_code="FL", hq_address_zip="22222", ), + self.auth_user, ) res = await repo.get_institution(query_session, "MINBANK123") assert res is not None @@ -247,19 +258,20 @@ async def test_add_institution_missing_required_fields( with pytest.raises(Exception) as e: await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Minimal Bank 123", lei="MINBANK123", ), + self.auth_user, ) - assert "not null constraint failed" in str(e.value).lower() + assert "field required" in str(e.value).lower() res = await repo.get_institution(query_session, "MINBANK123") assert res is None async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( transaction_session, - FinancialInstitutionDao( + FinancialInstitutionDto( name="Test Bank 234", lei="TESTBANK123", is_active=True, @@ -268,6 +280,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): hq_address_state_code="GA", hq_address_zip="00000", ), + self.auth_user, ) res = await repo.get_institutions(transaction_session) assert len(res) == 3 From e9404a04d321072ee6d854ccbb8d6d2b6e1f51c4 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Wed, 7 Feb 2024 11:26:57 -0500 Subject: [PATCH 4/5] feat: added basic test for listeners --- src/entities/listeners.py | 30 ++++++++------ src/entities/models/dto.py | 1 + src/entities/repos/institutions_repo.py | 1 + src/main.py | 4 +- tests/entities/test_listeners.py | 54 +++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 tests/entities/test_listeners.py diff --git a/src/entities/listeners.py b/src/entities/listeners.py index 6b03b76..bfb6be0 100644 --- a/src/entities/listeners.py +++ b/src/entities/listeners.py @@ -5,16 +5,8 @@ from entities.engine.engine import engine -async def setup_fi_dao_listeners(): - async with engine.begin() as connection: - fi_history, mapping_history = await connection.run_sync( - lambda conn: ( - Table("financial_institutions_history", Base.metadata, autoload_with=conn), - Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn), - ) - ) - - def insert_history( +def _setup_fi_history(fi_history: Table, mapping_history: Table): + def _insert_history( mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao ): new_version = target.version + 1 if target.version else 1 @@ -47,5 +39,19 @@ def insert_history( connection.execute(fi_history.insert().values(hist)) connection.execute(mapping_history.insert().values(types)) - event.listen(FinancialInstitutionDao, "before_insert", insert_history) - event.listen(FinancialInstitutionDao, "before_update", insert_history) + return _insert_history + + +async def setup_dao_listeners(): + async with engine.begin() as connection: + fi_history, mapping_history = await connection.run_sync( + lambda conn: ( + Table("financial_institutions_history", Base.metadata, autoload_with=conn), + Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn), + ) + ) + + insert_fi_history = _setup_fi_history(fi_history, mapping_history) + + event.listen(FinancialInstitutionDao, "before_insert", insert_fi_history) + event.listen(FinancialInstitutionDao, "before_update", insert_fi_history) diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index 90e98da..c7f312f 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -62,6 +62,7 @@ class FinancialInstitutionDto(FinancialInstitutionBase): top_holder_lei: str | None = None top_holder_legal_name: str | None = None top_holder_rssd_id: int | None = None + version: int | None = None class Config: from_attributes = True diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 48b3e36..457e533 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -77,6 +77,7 @@ async def upsert_institution( async with session.begin(): fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) + fi_data.pop("version", None) if "sbl_institution_types" in fi_data: types_association = [ diff --git a/src/main.py b/src/main.py index a6368a1..abe88bf 100644 --- a/src/main.py +++ b/src/main.py @@ -9,7 +9,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware from alembic.config import Config from alembic import command -from entities.listeners import setup_fi_dao_listeners +from entities.listeners import setup_dao_listeners from routers import admin_router, institutions_router @@ -34,7 +34,7 @@ async def lifespan(app_: FastAPI): log.info("Starting up...") log.info("run alembic upgrade head...") run_migrations() - await setup_fi_dao_listeners() + await setup_dao_listeners() yield log.info("Shutting down...") diff --git a/tests/entities/test_listeners.py b/tests/entities/test_listeners.py new file mode 100644 index 0000000..88dc853 --- /dev/null +++ b/tests/entities/test_listeners.py @@ -0,0 +1,54 @@ +from unittest.mock import Mock +from pytest_mock import MockerFixture + +from sqlalchemy import Connection, Table +from sqlalchemy.orm import Mapper, InstanceState, AttributeState + +from entities.models.dao import FinancialInstitutionDao, SBLInstitutionTypeDao, SblTypeMappingDao + +from entities.listeners import _setup_fi_history + + +class TestListeners: + fi_history: Table = Mock(Table) + mapping_history: Table = Mock(Table) + mapper: Mapper = Mock(Mapper) + connection: Connection = Mock(Connection) + target: FinancialInstitutionDao = FinancialInstitutionDao( + name="Test Bank 123", + lei="TESTBANK123", + is_active=True, + tax_id="123456789", + rssd_id=1234, + primary_federal_regulator_id="FRI1", + hmda_institution_type_id="HIT1", + sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))], + hq_address_street_1="Test Address Street 1", + hq_address_street_2="", + hq_address_city="Test City 1", + hq_address_state_code="GA", + hq_address_zip="00000", + parent_lei="PARENTTESTBANK123", + parent_legal_name="PARENT TEST BANK 123", + parent_rssd_id=12345, + top_holder_lei="TOPHOLDERLEI123", + top_holder_legal_name="TOP HOLDER LEI 123", + top_holder_rssd_id=123456, + modified_by="test_user_id", + ) + + def test_fi_history_listener(self, mocker: MockerFixture): + inspect_mock = mocker.patch("entities.listeners.inspect") + attr_mock1: AttributeState = Mock(AttributeState) + attr_mock1.key = "name" + attr_mock2: AttributeState = Mock(AttributeState) + attr_mock2.key = "event_time" + state_mock: InstanceState = Mock(InstanceState) + state_mock.attrs = [attr_mock1, attr_mock2] + self.fi_history.columns = {"name": "test"} + inspect_mock.return_value = state_mock + fi_listener = _setup_fi_history(self.fi_history, self.mapping_history) + fi_listener(self.mapper, self.connection, self.target) + inspect_mock.assert_called_once_with(self.target) + attr_mock1.load_history.assert_called_once() + self.fi_history.insert.assert_called_once() From 4490e0b8e8325d2ce21548798491804c64f1ab5b Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Wed, 7 Feb 2024 13:10:43 -0500 Subject: [PATCH 5/5] refactor: set event_time to onupdate in dao instead of explicitly specifying --- src/entities/models/dao.py | 2 +- src/entities/repos/institutions_repo.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 464a825..0eb3315 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -10,7 +10,7 @@ class Base(AsyncAttrs, DeclarativeBase): class AuditMixin(object): - event_time: Mapped[datetime] = mapped_column(server_default=func.now()) + event_time: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now()) class SblTypeMappingDao(Base): diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 457e533..b171a71 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import List from sqlalchemy import select, func @@ -88,7 +87,7 @@ async def upsert_institution( ] fi_data["sbl_institution_types"] = types_association - db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id, event_time=datetime.now())) + db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id)) await session.flush() await session.refresh(db_fi) return db_fi