Commit 3e0320a

feat: add fi tables versioning

lchen-2101 committed Feb 2, 2024 (1 parent: aa0c23b)

Showing 8 changed files with 195 additions and 9 deletions.
51 changes: 51 additions & 0 deletions db_revisions/versions/329c70502325_240131_fi_history_table.py
@@ -0,0 +1,51 @@
"""240131 fi history table
Revision ID: 329c70502325
Revises: 3f893e52d05c
Create Date: 2024-01-31 10:23:01.081439
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '329c70502325'
down_revision: Union[str, None] = '3f893e52d05c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "financial_institutions_history",
        sa.Column("version", sa.Integer(), nullable=False),
        sa.Column("lei", sa.String(), nullable=False),
        sa.Column("name", sa.String()),
        sa.Column("is_active", sa.Boolean()),
        sa.Column("tax_id", sa.String(9)),
        sa.Column("rssd_id", sa.Integer()),
        sa.Column("primary_federal_regulator_id", sa.String(4)),
        sa.Column("hmda_institution_type_id", sa.String()),
        sa.Column("hq_address_street_1", sa.String()),
        sa.Column("hq_address_street_2", sa.String()),
        sa.Column("hq_address_city", sa.String()),
        sa.Column("hq_address_state_code", sa.String(2)),
        sa.Column("hq_address_zip", sa.String(5)),
        sa.Column("parent_lei", sa.String(20)),
        sa.Column("parent_legal_name", sa.String()),
        sa.Column("parent_rssd_id", sa.Integer()),
        sa.Column("top_holder_lei", sa.String(20)),
        sa.Column("top_holder_legal_name", sa.String()),
        sa.Column("top_holder_rssd_id", sa.Integer()),
        sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False),
        sa.Column("modified_by", sa.String()),
        sa.Column("changeset", sa.JSON),
        sa.PrimaryKeyConstraint("lei", "version")
    )


def downgrade() -> None:
    op.drop_table("financial_institutions_history")
34 changes: 34 additions & 0 deletions db_revisions/versions/3f893e52d05c_240130_add_version.py
@@ -0,0 +1,34 @@
"""240130 add version
Revision ID: 3f893e52d05c
Revises: 6826f05140cd
Create Date: 2024-01-30 14:37:47.652233
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '3f893e52d05c'
down_revision: Union[str, None] = '6826f05140cd'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    with op.batch_alter_table("financial_institutions") as batch_op:
        batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1")))
        batch_op.add_column(sa.Column("modified_by", sa.String()))
    with op.batch_alter_table("fi_to_type_mapping") as batch_op:
        batch_op.add_column(sa.Column("version", type_=sa.Integer(), nullable=False, server_default=sa.text("1")))
        batch_op.add_column(sa.Column("modified_by", sa.String()))


def downgrade() -> None:
    op.drop_column("financial_institutions", "version")
    op.drop_column("financial_institutions", "modified_by")
    op.drop_column("fi_to_type_mapping", "version")
    op.drop_column("fi_to_type_mapping", "modified_by")
@@ -0,0 +1,35 @@
"""240131 fi type association history table
Revision ID: 8106d83ff594
Revises: 329c70502325
Create Date: 2024-01-31 10:23:21.163572
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '8106d83ff594'
down_revision: Union[str, None] = '329c70502325'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

def upgrade() -> None:
    op.create_table(
        "fi_to_type_mapping_history",
        sa.Column("version", sa.Integer()),
        sa.Column("fi_id", sa.String(), nullable=False),
        sa.Column("type_id", sa.String(), nullable=False),
        sa.Column("details", sa.String()),
        sa.Column("modified_by", sa.String()),
        sa.Column("event_time", sa.DateTime(), server_default=sa.func.now(), nullable=False),
        sa.Column("changeset", sa.JSON),
        sa.PrimaryKeyConstraint("fi_id", "type_id", "version")
    )


def downgrade() -> None:
    op.drop_table("fi_to_type_mapping_history")
51 changes: 51 additions & 0 deletions src/entities/listeners.py
@@ -0,0 +1,51 @@
from sqlalchemy import Connection, Table, event, inspect
from sqlalchemy.orm import Mapper

from .models.dao import Base, FinancialInstitutionDao
from entities.engine.engine import engine


async def setup_fi_dao_listeners():
    async with engine.begin() as connection:
        fi_history, mapping_history = await connection.run_sync(
            lambda conn: (
                Table("financial_institutions_history", Base.metadata, autoload_with=conn),
                Table("fi_to_type_mapping_history", Base.metadata, autoload_with=conn),
            )
        )

    def insert_history(
        mapper: Mapper[FinancialInstitutionDao], connection: Connection, target: FinancialInstitutionDao
    ):
        new_version = target.version + 1 if target.version else 1
        changes = {}
        state = inspect(target)
        for attr in state.attrs:
            if attr.key == "event_time":
                continue
            attr_hist = attr.load_history()
            if not attr_hist.has_changes():
                continue
            if attr.key == "sbl_institution_types":
                old_types = [o.as_dict() for o in attr_hist.deleted]
                new_types = [{**n.as_dict(), "version": new_version} for n in attr_hist.added]
                changes[attr.key] = {"old": old_types, "new": new_types}
            else:
                changes[attr.key] = {"old": attr_hist.deleted, "new": attr_hist.added}
        if changes:
            target.version = new_version
            for t in target.sbl_institution_types:
                t.version = new_version
            hist = target.__dict__.copy()
            hist.pop("event_time", None)
            history_columns = fi_history.columns.keys()
            for key in hist.copy():
                if key not in history_columns:
                    del hist[key]
            hist["changeset"] = changes
            types = [t.as_db_dict() for t in target.sbl_institution_types]
            connection.execute(fi_history.insert().values(hist))
            connection.execute(mapping_history.insert().values(types))

    event.listen(FinancialInstitutionDao, "before_insert", insert_history)
    event.listen(FinancialInstitutionDao, "before_update", insert_history)
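
The listener reflects the two history tables created by the migrations above, then registers one insert_history hook for both before_insert and before_update on FinancialInstitutionDao: it diffs every loaded attribute except event_time, bumps version on the institution and its type mappings, and writes a snapshot row whose changeset JSON holds the old and new values. The stand-alone sketch below shows the same pattern in isolation so it can be run and inspected; the toy Item model, items_history table, in-memory SQLite engine, and column set are illustrative assumptions, not part of this commit.

# Minimal, self-contained sketch of the history-listener pattern (sync SQLAlchemy 2.x,
# in-memory SQLite). All names below are hypothetical stand-ins for the DAOs above.
from sqlalchemy import JSON, Column, Integer, MetaData, String, Table, create_engine, event, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String())
    version: Mapped[int] = mapped_column(default=0)


# The history side stays a plain Table, mirroring how the commit reflects
# financial_institutions_history instead of mapping it as a DAO.
history_metadata = MetaData()
items_history = Table(
    "items_history",
    history_metadata,
    Column("id", Integer),
    Column("name", String),
    Column("version", Integer),
    Column("changeset", JSON),
)


def insert_history(mapper, connection, target):
    # Collect old/new values for every changed attribute, then snapshot the row.
    changes = {}
    for attr in inspect(target).attrs:
        attr_hist = attr.load_history()
        if attr_hist.has_changes():
            changes[attr.key] = {"old": list(attr_hist.deleted), "new": list(attr_hist.added)}
    if changes:
        target.version = (target.version or 0) + 1
        connection.execute(
            items_history.insert().values(
                id=target.id, name=target.name, version=target.version, changeset=changes
            )
        )


event.listen(Item, "before_insert", insert_history)
event.listen(Item, "before_update", insert_history)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
history_metadata.create_all(engine)

with Session(engine) as session:
    session.add(Item(id=1, name="first"))
    session.commit()
    item = session.get(Item, 1)
    item.name = "renamed"
    session.commit()

with engine.connect() as conn:
    for row in conn.execute(items_history.select()):
        print(row)  # one history row per write: version 1 for the insert, version 2 for the rename

In the commit itself, FinancialInstitutionDao, the reflected financial_institutions_history and fi_to_type_mapping_history tables, and the as_dict()/as_db_dict() helpers play the roles that Item, items_history, and the inline value collection play here.
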
18 changes: 14 additions & 4 deletions src/entities/models/dao.py
@@ -1,8 +1,7 @@
 from datetime import datetime
 from typing import List
-from sqlalchemy import ForeignKey, func, String
-from sqlalchemy.orm import Mapped, mapped_column, relationship
-from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy import ForeignKey, func, String, inspect
+from sqlalchemy.orm import Mapped, mapped_column, relationship, DeclarativeBase
 from sqlalchemy.ext.asyncio import AsyncAttrs


@@ -16,14 +15,24 @@ class AuditMixin(object):

 class SblTypeMappingDao(Base):
     __tablename__ = "fi_to_type_mapping"
+    version: Mapped[int] = mapped_column(nullable=False, default=0)
+    __mapper_args__ = {"version_id_col": version, "version_id_generator": False}
     lei: Mapped[str] = mapped_column("fi_id", ForeignKey("financial_institutions.lei"), primary_key=True)
     type_id: Mapped[str] = mapped_column(ForeignKey("sbl_institution_type.id"), primary_key=True)
     sbl_type: Mapped["SBLInstitutionTypeDao"] = relationship(lazy="selectin")
-    details: Mapped[str] = mapped_column(nullable=True)
+    details: Mapped[str] = mapped_column()
+    modified_by: Mapped[str] = mapped_column()
+    def as_db_dict(self):
+        data = {}
+        for attr, column in inspect(self.__class__).c.items():
+            data[column.name] = getattr(self, attr)
+        return data


 class FinancialInstitutionDao(AuditMixin, Base):
     __tablename__ = "financial_institutions"
+    version: Mapped[int] = mapped_column(nullable=False, default=0)
+    __mapper_args__ = {"version_id_col": version, "version_id_generator": False}
     lei: Mapped[str] = mapped_column(unique=True, index=True, primary_key=True)
     name: Mapped[str] = mapped_column(index=True)
     is_active: Mapped[bool] = mapped_column(index=True)
@@ -49,6 +58,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
     top_holder_lei: Mapped[str] = mapped_column(String(20), nullable=True)
     top_holder_legal_name: Mapped[str] = mapped_column(nullable=True)
     top_holder_rssd_id: Mapped[int] = mapped_column(nullable=True)
+    modified_by: Mapped[str] = mapped_column()


 class FinancialInstitutionDomainDao(AuditMixin, Base):
11 changes: 7 additions & 4 deletions src/entities/repos/institutions_repo.py
@@ -1,9 +1,12 @@
+from datetime import datetime
 from typing import List

 from sqlalchemy import select, func
 from sqlalchemy.orm import joinedload
 from sqlalchemy.ext.asyncio import AsyncSession

+from regtech_api_commons.models import AuthenticatedUser
+
 from .repo_utils import query_type

 from entities.models import (
@@ -68,21 +71,21 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator
     return await query_type(session, FederalRegulatorDao)


-async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
+async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto, user: AuthenticatedUser) -> FinancialInstitutionDao:

(CI annotation from GitHub Actions / ruff on the new signature above: E501 at src/entities/repos/institutions_repo.py:74, line too long, 133 > 120 characters.)
     async with session.begin():
         fi_data = fi.__dict__.copy()
         fi_data.pop("_sa_instance_state", None)

         if "sbl_institution_types" in fi_data:
             types_association = [
-                SblTypeMappingDao(type_id=t)
+                SblTypeMappingDao(type_id=t, lei=fi.lei, modified_by=user.id)
                 if isinstance(t, str)
-                else SblTypeMappingDao(type_id=t.id, details=t.details)
+                else SblTypeMappingDao(type_id=t.id, details=t.details, lei=fi.lei, modified_by=user.id)
                 for t in fi.sbl_institution_types
             ]
             fi_data["sbl_institution_types"] = types_association

-        db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
+        db_fi = await session.merge(FinancialInstitutionDao(**fi_data, modified_by=user.id, event_time=datetime.now()))
         await session.flush()
         await session.refresh(db_fi)
         return db_fi
2 changes: 2 additions & 0 deletions src/main.py
@@ -9,6 +9,7 @@
 from starlette.middleware.authentication import AuthenticationMiddleware
 from alembic.config import Config
 from alembic import command
+from entities.listeners import setup_fi_dao_listeners

 from routers import admin_router, institutions_router

@@ -33,6 +34,7 @@ async def lifespan(app_: FastAPI):
     log.info("Starting up...")
     log.info("run alembic upgrade head...")
     run_migrations()
+    await setup_fi_dao_listeners()
     yield
     log.info("Shutting down...")

2 changes: 1 addition & 1 deletion src/routers/institutions.py
@@ -51,7 +51,7 @@ async def create_institution(
     request: Request,
     fi: FinancialInstitutionDto,
 ):
-    db_fi = await repo.upsert_institution(request.state.db_session, fi)
+    db_fi = await repo.upsert_institution(request.state.db_session, fi, request.user)
     kc_id = oauth2_admin.upsert_group(fi.lei, fi.name)
     return kc_id, db_fi
