diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 8605ce6..4442b50 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -30,11 +30,7 @@ async def get_institutions( if leis is not None: stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) elif d := domain.strip(): - search = "%{}%".format(d) - stmt = stmt.join( - FinancialInstitutionDomainDao, - FinancialInstitutionDao.lei == FinancialInstitutionDomainDao.lei, - ).filter(FinancialInstitutionDomainDao.domain.like(search)) + stmt = stmt.join(FinancialInstitutionDomainDao).filter(FinancialInstitutionDomainDao.domain == d) res = await session.scalars(stmt) return res.unique().all() diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 3b7279a..6af55cd 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -16,27 +16,47 @@ async def setup( self, transaction_session: AsyncSession, ): - fi_dao_123, fi_dao_456 = FinancialInstitutionDao( - name="Test Bank 123", - lei="TESTBANK123", - domains=[FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")], - ), FinancialInstitutionDao( - name="Test Bank 456", - lei="TESTBANK456", - domains=[FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")], + fi_dao_123, fi_dao_456, fi_dao_sub_456 = ( + FinancialInstitutionDao( + name="Test Bank 123", + lei="TESTBANK123", + domains=[FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")], + ), + FinancialInstitutionDao( + name="Test Bank 456", + lei="TESTBANK456", + domains=[FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")], + ), + FinancialInstitutionDao( + name="Test Sub Bank 456", + lei="TESTSUBBANK456", + domains=[FinancialInstitutionDomainDao(domain="sub.test.bank.2", lei="TESTSUBBANK456")], + ), ) transaction_session.add(fi_dao_123) transaction_session.add(fi_dao_456) + transaction_session.add(fi_dao_sub_456) await transaction_session.commit() async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) - assert len(res) == 2 + assert len(res) == 3 async def test_get_institutions_by_domain(self, query_session: AsyncSession): + # verify 'generic' domain queries don't work + res = await repo.get_institutions(query_session, domain="bank") + assert len(res) == 0 + res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 + # shouldn't find sub.test.bank.2 + res = await repo.get_institutions(query_session, domain="test.bank.2") + assert len(res) == 1 + + res = await repo.get_institutions(query_session, domain="sub.test.bank.2") + assert len(res) == 1 + async def test_get_institutions_by_domain_not_existing(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 @@ -55,7 +75,7 @@ async def test_add_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="New Bank 123", lei="NEWBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 3 + assert len(res) == 4 async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( @@ -63,7 +83,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 2 + assert len(res) == 3 assert res[0].name == "Test Bank 234" async def test_add_domains(self, transaction_session: AsyncSession, query_session: AsyncSession):