diff --git a/src/dependencies.py b/src/dependencies.py index 384c7c9..e9e178d 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -2,6 +2,10 @@ from typing import Annotated from fastapi import Depends, HTTPException, Request from sqlalchemy.ext.asyncio import AsyncSession +from typing import List, Optional +from itertools import chain + +from fastapi import Query from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -35,3 +39,18 @@ def request_needs_domain_check(request: Request) -> bool: async def email_domain_denied(session: AsyncSession, email: str) -> bool: return not await repo.is_email_domain_allowed(session, email) + + +def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: + """ + Parses leis from list of one or multiple strings to a list of + multiple distinct lei strings. + Returns empty list when nothing is passed in + Ex1: ['lei1,lei2'] -> ['lei1', 'lei2'] + Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4'] + """ + + if leis: + return list(chain.from_iterable([x.split(",") for x in leis])) + else: + return None diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 887809b..fb7ef88 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -14,7 +14,11 @@ async def get_institutions( - session: AsyncSession, domain: str = "", page: int = 0, count: int = 100 + session: AsyncSession, + leis: List[str] = None, + domain: str = "", + page: int = 0, + count: int = 100, ) -> List[FinancialInstitutionDao]: async with session.begin(): stmt = ( @@ -23,7 +27,9 @@ async def get_institutions( .limit(count) .offset(page * count) ) - if d := domain.strip(): + if leis: + stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) + elif d := domain.strip(): search = "%{}%".format(d) stmt = stmt.join( FinancialInstitutionDomainDao, diff --git a/src/routers/institutions.py b/src/routers/institutions.py index e54ed32..440ca13 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,6 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router +from dependencies import parse_leis from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -28,11 +29,14 @@ async def set_db( @requires("authenticated") async def get_institutions( request: Request, + leis: List[str] = Depends(parse_leis), domain: str = "", page: int = 0, count: int = 100, ): - return await repo.get_institutions(request.state.db_session, domain, page, count) + return await repo.get_institutions( + request.state.db_session, leis, domain, page, count + ) @router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index f042c7f..584fd31 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -16,22 +16,29 @@ async def setup( self, transaction_session: AsyncSession, ): - fi_dao = FinancialInstitutionDao( + fi_dao_123, fi_dao_456 = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") + FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123") + ], + ), FinancialInstitutionDao( + name="Test Bank 456", + lei="TESTBANK456", + domains=[ + FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456") ], ) - transaction_session.add(fi_dao) + transaction_session.add(fi_dao_123) + transaction_session.add(fi_dao_456) await transaction_session.commit() async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) - assert len(res) == 1 + assert len(res) == 2 async def test_get_institutions_by_domain(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, domain="test.bank") + res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 async def test_get_institutions_by_domain_not_existing( @@ -40,13 +47,25 @@ async def test_get_institutions_by_domain_not_existing( res = await repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 + async def test_get_institutions_by_lei_list(self, query_session: AsyncSession): + res = await repo.get_institutions( + query_session, leis=["TESTBANK123", "TESTBANK456"] + ) + assert len(res) == 2 + + async def test_get_institutions_by_lei_list_item_not_existing( + self, query_session: AsyncSession + ): + res = await repo.get_institutions(query_session, leis=["NOTTESTBANK"]) + assert len(res) == 0 + async def test_add_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( transaction_session, FinancialInstitutionDao(name="New Bank 123", lei="NEWBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 2 + assert len(res) == 3 async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( @@ -54,7 +73,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 1 + assert len(res) == 2 assert res[0].name == "Test Bank 234" async def test_add_domains(