diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 8605ce6..af25fcf 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -23,6 +23,7 @@ async def get_institutions( async with session.begin(): stmt = ( select(FinancialInstitutionDao) + .join(FinancialInstitutionDomainDao) .options(joinedload(FinancialInstitutionDao.domains)) .limit(count) .offset(page * count) @@ -30,11 +31,7 @@ async def get_institutions( if leis is not None: stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) elif d := domain.strip(): - search = "%{}%".format(d) - stmt = stmt.join( - FinancialInstitutionDomainDao, - FinancialInstitutionDao.lei == FinancialInstitutionDomainDao.lei, - ).filter(FinancialInstitutionDomainDao.domain.like(search)) + stmt = stmt.filter(FinancialInstitutionDomainDao.domain == d) res = await session.scalars(stmt) return res.unique().all() diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 3b7279a..75ad6db 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -16,7 +16,7 @@ async def setup( self, transaction_session: AsyncSession, ): - fi_dao_123, fi_dao_456 = FinancialInstitutionDao( + fi_dao_123, fi_dao_456, fi_dao_sub_456 = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", domains=[FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")], @@ -24,18 +24,34 @@ async def setup( name="Test Bank 456", lei="TESTBANK456", domains=[FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")], + ), FinancialInstitutionDao( + name="Test Sub Bank 456", + lei="TESTSUBBANK456", + domains=[FinancialInstitutionDomainDao(domain="sub.test.bank.2", lei="TESTSUBBANK456")], ) transaction_session.add(fi_dao_123) transaction_session.add(fi_dao_456) + transaction_session.add(fi_dao_sub_456) await transaction_session.commit() async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) - assert len(res) == 2 + assert len(res) == 3 - async def test_get_institutions_by_domain(self, query_session: AsyncSession): + async def test_get_institutions_by_domain(self, query_session: AsyncSession): + #verify 'generic' domain queries don't work + res = await repo.get_institutions(query_session, domain="bank") + assert len(res) == 0 + res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 + + #shouldn't find sub.test.bank.2 + res = await repo.get_institutions(query_session, domain="test.bank.2") + assert len(res) == 1 + + res = await repo.get_institutions(query_session, domain="sub.test.bank.2") + assert len(res) == 1 async def test_get_institutions_by_domain_not_existing(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, domain="testing.bank") @@ -63,7 +79,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 2 + assert len(res) == 3 assert res[0].name == "Test Bank 234" async def test_add_domains(self, transaction_session: AsyncSession, query_session: AsyncSession):