From bef48712cb30496c61c3d48c17ccc4099ff3b88e Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 8 Jan 2024 13:30:37 -0500 Subject: [PATCH] fix: updated upsert to flush and refresh the dao before returning (#86) closes #85 --- src/entities/models/dao.py | 2 +- src/entities/repos/institutions_repo.py | 13 +++---------- tests/entities/repos/test_institutions_repo.py | 3 ++- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index e4bf81c..f8c1eba 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -20,7 +20,7 @@ class FinancialInstitutionDao(AuditMixin, Base): name: Mapped[str] = mapped_column(index=True) is_active: Mapped[bool] = mapped_column(index=True) domains: Mapped[List["FinancialInstitutionDomainDao"]] = relationship( - "FinancialInstitutionDomainDao", back_populates="fi" + "FinancialInstitutionDomainDao", back_populates="fi", lazy="selectin" ) tax_id: Mapped[str] = mapped_column(String(9), unique=True, nullable=True) rssd_id: Mapped[int] = mapped_column(unique=True, nullable=True) diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 177ca03..3bb1340 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -69,18 +69,11 @@ async def get_federal_regulators(session: AsyncSession) -> List[FederalRegulator async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao: async with session.begin(): - stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei) - res = await session.execute(stmt) - db_fi = res.scalar_one_or_none() fi_data = fi.__dict__.copy() fi_data.pop("_sa_instance_state", None) - if db_fi is None: - db_fi = FinancialInstitutionDao(**fi_data) - session.add(db_fi) - else: - for key, value in fi_data.items(): - setattr(db_fi, key, value) - await session.commit() + db_fi = await session.merge(FinancialInstitutionDao(**fi_data)) + await session.flush([db_fi]) + await session.refresh(db_fi) return db_fi diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index e8c9d8e..74718ed 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -188,7 +188,7 @@ async def test_get_institutions_by_lei_list_item_not_existing(self, query_sessio assert len(res) == 0 async def test_add_institution(self, transaction_session: AsyncSession): - await repo.upsert_institution( + db_fi = await repo.upsert_institution( transaction_session, FinancialInstitutionDao( name="New Bank 123", @@ -212,6 +212,7 @@ async def test_add_institution(self, transaction_session: AsyncSession): top_holder_rssd_id=876543, ), ) + assert db_fi.domains == [] res = await repo.get_institutions(transaction_session) assert len(res) == 4