Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/87 add versioning to fi tables #101

Merged
merged 5 commits into from
Feb 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions db_revisions/versions/329c70502325_240131_fi_history_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""240131 fi history table

Revision ID: 329c70502325
Revises: 3f893e52d05c
Create Date: 2024-01-31 10:23:01.081439

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "329c70502325"
down_revision: Union[str, None] = "3f893e52d05c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"financial_institutions_history",
sa.Column("version", sa.Integer(), nullable=False),
sa.Column("lei", sa.String(), nullable=False),
sa.Column("name", sa.String()),
sa.Column("is_active", sa.Boolean()),
sa.Column("tax_id", sa.String(9)),
sa.Column("rssd_id", sa.Integer()),
sa.Column("primary_federal_regulator_id", sa.String(4)),
sa.Column("hmda_institution_type_id", sa.String()),
sa.Column("hq_address_street_1", sa.String()),
sa.Column("hq_address_street_2", sa.String()),
sa.Column("hq_address_city", sa.String()),
sa.Column("hq_address_state_code", sa.String(2)),
sa.Column("hq_address_zip", sa.String(5)),
sa.Column("parent_lei", sa.String(20)),
sa.Column("parent_legal_name", sa.String()),
sa.Column("parent_rssd_id", sa.Integer()),
sa.Column("top_holder_lei", sa.String(20)),
sa.Column("top_holder_legal_name", sa.String()),
sa.Column("top_holder_rssd_id", sa.Integer()),
sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column("modified_by", sa.String()),
sa.Column("changeset", sa.JSON),
sa.PrimaryKeyConstraint("lei", "version"),
)


def downgrade() -> None:
op.drop_table("financial_institutions_history")
34 changes: 34 additions & 0 deletions db_revisions/versions/3f893e52d05c_240130_add_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""240130 add version

Revision ID: 3f893e52d05c
Revises: 6826f05140cd
Create Date: 2024-01-30 14:37:47.652233

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "3f893e52d05c"
down_revision: Union[str, None] = "6826f05140cd"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
with op.batch_alter_table("financial_institutions") as batch_op:
batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1")))
batch_op.add_column(sa.Column("modified_by", sa.String()))
with op.batch_alter_table("fi_to_type_mapping") as batch_op:
batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1")))
batch_op.add_column(sa.Column("modified_by", sa.String()))


def downgrade() -> None:
op.drop_column("financial_institutions", "version")
op.drop_column("financial_institutions", "modified_by")
op.drop_column("fi_to_type_mapping", "version")
op.drop_column("fi_to_type_mapping", "modified_by")
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""240131 fi type association history table

Revision ID: 8106d83ff594
Revises: 329c70502325
Create Date: 2024-01-31 10:23:21.163572

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "8106d83ff594"
down_revision: Union[str, None] = "329c70502325"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"fi_to_type_mapping_history",
sa.Column("version", sa.Integer()),
sa.Column("fi_id", sa.String(), nullable=False),
sa.Column("type_id", sa.String(), nullable=False),
sa.Column("details", sa.String()),
sa.Column("modified_by", sa.String()),
sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column("changeset", sa.JSON),
sa.PrimaryKeyConstraint("fi_id", "type_id", "version"),
)


def downgrade() -> None:
op.drop_table("fi_to_type_mapping_history")
57 changes: 57 additions & 0 deletions src/entities/listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from sqlalchemy import Connection, Table, event, inspect
from sqlalchemy.orm import Mapper

from .models.dao import Base, FinancialInstitutionDao
from entities.engine.engine import engine


def _setup_fi_history(fi_history: Table, mapping_history: Table):
def _insert_history(
mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao
):
new_version = target.version + 1 if target.version else 1
lchen-2101 marked this conversation as resolved.
Show resolved Hide resolved
changes = {}
state = inspect(target)
for attr in state.attrs:
if attr.key == "event_time":
continue
attr_hist = attr.load_history()
if not attr_hist.has_changes():
continue
if attr.key == "sbl_institution_types":
old_types = [o.as_db_dict() for o in attr_hist.deleted]
new_types = [{**n.as_db_dict(), "version": new_version} for n in attr_hist.added]
changes[attr.key] = {"old": old_types, "new": new_types}
else:
changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added}
if changes:
target.version = new_version
for t in target.sbl_institution_types:
t.version = new_version
hist = target.__dict__.copy()
hist.pop("event_time", None)
history_columns = fi_history.columns.keys()
for key in hist.copy():
if key not in history_columns:
del hist[key]
hist["changeset"] = changes
types = [t.as_db_dict() for t in target.sbl_institution_types]
connection.execute(fi_history.insert().values(hist))
connection.execute(mapping_history.insert().values(types))

return _insert_history


async def setup_dao_listeners():
async with engine.begin() as connection:
fi_history, mapping_history = await connection.run_sync(
lambda conn: (
Table("financial_institutions_history", Base.metadata, autoload_with=conn),
Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn),
)
)

insert_fi_history = _setup_fi_history(fi_history, mapping_history)

event.listen(FinancialInstitutionDao, "before_insert", insert_fi_history)
event.listen(FinancialInstitutionDao, "before_update", insert_fi_history)
lchen-2101 marked this conversation as resolved.
Show resolved Hide resolved
19 changes: 15 additions & 4 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from datetime import datetime
from typing import List
from sqlalchemy import ForeignKey, func, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import ForeignKey, func, String, inspect
from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs


@@ -11,19 +10,30 @@ class Base(AsyncAttrs, DeclarativeBase):


class AuditMixin(object):
event_time: Mapped[datetime] = mapped_column(server_default=func.now())
event_time: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now())


class SblTypeMappingDao(Base):
__tablename__ = "fi_to_type_mapping"
version: Mapped[int] = mapped_column(nullable=False, default=0)
__mapper_args__ = {"version_id_col": version, "version_id_generator": False}
lchen-2101 marked this conversation as resolved.
Show resolved Hide resolved
lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True)
type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True)
sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin")
details: Mapped[str] = mapped_column(nullable=True)
modified_by: Mapped[str] = mapped_column()

def as_db_dict(self):
data = {}
for attr, column in inspect(self.__class__).c.items():
data[column.name] = getattr(self, attr)
return data


class FinancialInstitutionDao(AuditMixin, Base):
__tablename__ = "financial_institutions"
version: Mapped[int] = mapped_column(nullable=False, default=0)
__mapper_args__ = {"version_id_col": version, "version_id_generator": False}
lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True)
name: Mapped[str] = mapped_column(index=True)
is_active: Mapped[bool] = mapped_column(index=True)
@@ -49,6 +59,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
top_holder_lei: Mapped[str] = mapped_column(String(20), nullable=True)
top_holder_legal_name: Mapped[str] = mapped_column(nullable=True)
top_holder_rssd_id: Mapped[int] = mapped_column(nullable=True)
modified_by: Mapped[str] = mapped_column()


class FinancialInstitutionDomainDao(AuditMixin, Base):
1 change: 1 addition & 0 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@ class FinancialInstitutionDto(FinancialInstitutionBase):
top_holder_lei: str | None = None
top_holder_legal_name: str | None = None
top_holder_rssd_id: int | None = None
version: int | None = None

class Config:
from_attributes = True
13 changes: 9 additions & 4 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.ext.asyncio import AsyncSession

from regtech_api_commons.models import AuthenticatedUser

from .repo_utils import query_type

from entities.models import (
@@ -68,21 +70,24 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator
return await query_type(session, FederalRegulatorDao)


async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser
) -> FinancialInstitutionDao:
async with session.begin():
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
fi_data.pop("version", None)

if "sbl_institution_types" in fi_data:
types_association = [
SblTypeMappingDao(type_id=t)
SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id)
if isinstance(t, str)
else SblTypeMappingDao(type_id=t.id, details=t.details)
else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id)
for t in fi.sbl_institution_types
]
fi_data["sbl_institution_types"] = types_association

db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id))
await session.flush()
await session.refresh(db_fi)
return db_fi
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from starlette.middleware.authentication import AuthenticationMiddleware
from alembic.config import Config
from alembic import command
from entities.listeners import setup_dao_listeners

from routers import admin_router, institutions_router

@@ -33,6 +34,7 @@ async def lifespan(app_: FastAPI):
log.info("Starting up...")
log.info("run alembic upgrade head...")
run_migrations()
await setup_dao_listeners()
yield
log.info("Shutting down...")

2 changes: 1 addition & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ async def create_institution(
request: Request,
fi: FinancialInstitutionDto,
):
db_fi = await repo.upsert_institution(request.state.db_session, fi)
db_fi = await repo.upsert_institution(request.state.db_session, fi, request.user)
kc_id = oauth2_admin.upsert_group(fi.lei, fi.name)
return kc_id, db_fi

4 changes: 3 additions & 1 deletion tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,7 @@ def auth_mock(mocker: MockerFixture) -> Mock:
@pytest.fixture
def authed_user_mock(auth_mock: Mock) -> Mock:
claims = {
"id": "test_user_id",
"name": "test",
"preferred_username": "test_user",
"email": "test@local.host",
@@ -56,7 +57,7 @@ def unauthed_user_mock(auth_mock: Mock) -> Mock:


@pytest.fixture
def get_institutions_mock(mocker: MockerFixture) -> Mock:
def get_institutions_mock(mocker: MockerFixture, authed_user_mock: Mock) -> Mock:
mock = mocker.patch("entities.repos.institutions_repo.get_institutions")
mock.return_value = [
FinancialInstitutionDao(
@@ -83,6 +84,7 @@ def get_institutions_mock(mocker: MockerFixture) -> Mock:
top_holder_lei="TOPHOLDERLEI123",
top_holder_legal_name="TOP HOLDER LEI 123",
top_holder_rssd_id=123456,
modified_by="test_user_id",
)
]
return mock
27 changes: 20 additions & 7 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
@@ -17,9 +17,12 @@
SblTypeAssociationDto,
)
from entities.repos import institutions_repo as repo
from regtech_api_commons.models import AuthenticatedUser


class TestInstitutionsRepo:
auth_user: AuthenticatedUser = AuthenticatedUser.from_claim({"id": "test_user_id"})

@pytest.fixture(scope="function", autouse=True)
async def setup(
self,
@@ -55,7 +58,7 @@ async def setup(
rssd_id=1234,
primary_federal_regulator_id="FRI1",
hmda_institution_type_id="HIT1",
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1)],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit1, modified_by="test_user_id")],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
@@ -67,6 +70,7 @@ async def setup(
top_holder_lei="TOPHOLDERLEI123",
top_holder_legal_name="TOP HOLDER LEI 123",
top_holder_rssd_id=123456,
modified_by="test_user_id",
),
FinancialInstitutionDao(
name="Test Bank 456",
@@ -77,7 +81,7 @@ async def setup(
rssd_id=4321,
primary_federal_regulator_id="FRI2",
hmda_institution_type_id="HIT2",
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2)],
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit2, modified_by="test_user_id")],
hq_address_street_1="Test Address Street 2",
hq_address_street_2="",
hq_address_city="Test City 2",
@@ -89,6 +93,7 @@ async def setup(
top_holder_lei="TOPHOLDERLEI456",
top_holder_legal_name="TOP HOLDER LEI 456",
top_holder_rssd_id=654321,
modified_by="test_user_id",
),
FinancialInstitutionDao(
name="Test Sub Bank 456",
@@ -99,7 +104,9 @@ async def setup(
rssd_id=2134,
primary_federal_regulator_id="FRI3",
hmda_institution_type_id="HIT3",
sbl_institution_types=[SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, details="test")],
sbl_institution_types=[
SblTypeMappingDao(sbl_type=sbl_it_dao_sit3, modified_by="test_user_id", details="test")
],
hq_address_street_1="Test Address Street 3",
hq_address_street_2="",
hq_address_city="Test City 3",
@@ -111,6 +118,7 @@ async def setup(
top_holder_lei="TOPHOLDERLEI456",
top_holder_legal_name="TOP HOLDER LEI SUB BANK 456",
top_holder_rssd_id=321654,
modified_by="test_user_id",
),
)

@@ -213,7 +221,9 @@ async def test_add_institution(self, transaction_session: AsyncSession):
top_holder_lei="TOPHOLDERNEWBANKLEI123",
top_holder_legal_name="TOP HOLDER NEW BANK LEI 123",
top_holder_rssd_id=876543,
modified_by="test_user_id",
),
self.auth_user,
)
assert db_fi.domains == []
res = await repo.get_institutions(transaction_session)
@@ -227,7 +237,7 @@ async def test_add_institution_only_required_fields(
):
await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(
FinancialInstitutionDto(
name="Minimal Bank 123",
lei="MINBANK123",
is_active=True,
@@ -236,6 +246,7 @@ async def test_add_institution_only_required_fields(
hq_address_state_code="FL",
hq_address_zip="22222",
),
self.auth_user,
)
res = await repo.get_institution(query_session, "MINBANK123")
assert res is not None
@@ -247,19 +258,20 @@ async def test_add_institution_missing_required_fields(
with pytest.raises(Exception) as e:
await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(
FinancialInstitutionDto(
name="Minimal Bank 123",
lei="MINBANK123",
),
self.auth_user,
)
assert "not null constraint failed" in str(e.value).lower()
assert "field required" in str(e.value).lower()
res = await repo.get_institution(query_session, "MINBANK123")
assert res is None

async def test_update_institution(self, transaction_session: AsyncSession):
await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(
FinancialInstitutionDto(
name="Test Bank 234",
lei="TESTBANK123",
is_active=True,
@@ -268,6 +280,7 @@ async def test_update_institution(self, transaction_session: AsyncSession):
hq_address_state_code="GA",
hq_address_zip="00000",
),
self.auth_user,
)
res = await repo.get_institutions(transaction_session)
assert len(res) == 3
54 changes: 54 additions & 0 deletions tests/entities/test_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from unittest.mock import Mock
from pytest_mock import MockerFixture

from sqlalchemy import Connection, Table
from sqlalchemy.orm import Mapper, InstanceState, AttributeState

from entities.models.dao import FinancialInstitutionDao, SBLInstitutionTypeDao, SblTypeMappingDao

from entities.listeners import _setup_fi_history


class TestListeners:
fi_history: Table = Mock(Table)
mapping_history: Table = Mock(Table)
mapper: Mapper = Mock(Mapper)
connection: Connection = Mock(Connection)
target: FinancialInstitutionDao = FinancialInstitutionDao(
name="Test Bank 123",
lei="TESTBANK123",
is_active=True,
tax_id="123456789",
rssd_id=1234,
primary_federal_regulator_id="FRI1",
hmda_institution_type_id="HIT1",
sbl_institution_types=[SblTypeMappingDao(sbl_type=SBLInstitutionTypeDao(id="SIT1", name="SIT1"))],
hq_address_street_1="Test Address Street 1",
hq_address_street_2="",
hq_address_city="Test City 1",
hq_address_state_code="GA",
hq_address_zip="00000",
parent_lei="PARENTTESTBANK123",
parent_legal_name="PARENT TEST BANK 123",
parent_rssd_id=12345,
top_holder_lei="TOPHOLDERLEI123",
top_holder_legal_name="TOP HOLDER LEI 123",
top_holder_rssd_id=123456,
modified_by="test_user_id",
)

def test_fi_history_listener(self, mocker: MockerFixture):
inspect_mock = mocker.patch("entities.listeners.inspect")
attr_mock1: AttributeState = Mock(AttributeState)
attr_mock1.key = "name"
attr_mock2: AttributeState = Mock(AttributeState)
attr_mock2.key = "event_time"
state_mock: InstanceState = Mock(InstanceState)
state_mock.attrs = [attr_mock1, attr_mock2]
self.fi_history.columns = {"name": "test"}
inspect_mock.return_value = state_mock
fi_listener = _setup_fi_history(self.fi_history, self.mapping_history)
fi_listener(self.mapper, self.connection, self.target)
inspect_mock.assert_called_once_with(self.target)
attr_mock1.load_history.assert_called_once()
self.fi_history.insert.assert_called_once()