Skip to content

Commit

Permalink
fix: updated upsert to flush and refresh the dao before returning (#86)
Browse files Browse the repository at this point in the history
closes #85
  • Loading branch information
lchen-2101 authored Jan 8, 2024
1 parent 775a81a commit bef4871
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
name: Mapped[str] = mapped_column(index=True)
is_active: Mapped[bool] = mapped_column(index=True)
domains: Mapped[List["FinancialInstitutionDomainDao"]] = relationship(
"FinancialInstitutionDomainDao", back_populates="fi"
"FinancialInstitutionDomainDao", back_populates="fi", lazy="selectin"
)
tax_id: Mapped[str] = mapped_column(String(9), unique=True, nullable=True)
rssd_id: Mapped[int] = mapped_column(unique=True, nullable=True)
Expand Down
13 changes: 3 additions & 10 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,11 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator

async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei)
res = await session.execute(stmt)
db_fi = res.scalar_one_or_none()
fi_data = fi.__dict__.copy()
fi_data.pop("_sa_instance_state", None)
if db_fi is None:
db_fi = FinancialInstitutionDao(**fi_data)
session.add(db_fi)
else:
for key, value in fi_data.items():
setattr(db_fi, key, value)
await session.commit()
db_fi = await session.merge(FinancialInstitutionDao(**fi_data))
await session.flush([db_fi])
await session.refresh(db_fi)
return db_fi


Expand Down
3 changes: 2 additions & 1 deletion tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ async def test_get_institutions_by_lei_list_item_not_existing(self, query_sessio
assert len(res) == 0

async def test_add_institution(self, transaction_session: AsyncSession):
await repo.upsert_institution(
db_fi = await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(
name="New Bank 123",
Expand All @@ -212,6 +212,7 @@ async def test_add_institution(self, transaction_session: AsyncSession):
top_holder_rssd_id=876543,
),
)
assert db_fi.domains == []
res = await repo.get_institutions(transaction_session)
assert len(res) == 4

Expand Down

0 comments on commit bef4871

Please sign in to comment.