Skip to content

Commit

Permalink
Added endpoint to return lei data from list of leis (#20)
Browse files Browse the repository at this point in the history
Added endpoint to return lei data from request containing list of leis.
Added repo to gather data from db. Added simple error handling for when
none of leis from list are found.

---------

Co-authored-by: lchen-2101 <[email protected]>
  • Loading branch information
guffee23 and lchen-2101 authored Sep 15, 2023
1 parent 42fbf3a commit 29f1137
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
19 changes: 19 additions & 0 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from itertools import chain

from fastapi import Query

from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand Down Expand Up @@ -35,3 +39,18 @@ def request_needs_domain_check(request: Request) -> bool:

async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_email_domain_allowed(session, email)


def parse_leis(leis: List[str] = Query(None)) -> Optional[List]:
"""
Parses leis from list of one or multiple strings to a list of
multiple distinct lei strings.
Returns empty list when nothing is passed in
Ex1: ['lei1,lei2'] -> ['lei1', 'lei2']
Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4']
"""

if leis:
return list(chain.from_iterable([x.split(",") for x in leis]))
else:
return None
10 changes: 8 additions & 2 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@


async def get_institutions(
session: AsyncSession, domain: str = "", page: int = 0, count: int = 100
session: AsyncSession,
leis: List[str] = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> List[FinancialInstitutionDao]:
async with session.begin():
stmt = (
Expand All @@ -23,7 +27,9 @@ async def get_institutions(
.limit(count)
.offset(page * count)
)
if d := domain.strip():
if leis:
stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
search = "%{}%".format(d)
stmt = stmt.join(
FinancialInstitutionDomainDao,
Expand Down
6 changes: 5 additions & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from dependencies import parse_leis
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand All @@ -28,11 +29,14 @@ async def set_db(
@requires("authenticated")
async def get_institutions(
request: Request,
leis: List[str] = Depends(parse_leis),
domain: str = "",
page: int = 0,
count: int = 100,
):
return await repo.get_institutions(request.state.db_session, domain, page, count)
return await repo.get_institutions(
request.state.db_session, leis, domain, page, count
)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
Expand Down
33 changes: 26 additions & 7 deletions tests/entities/repos/test_institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,29 @@ async def setup(
self,
transaction_session: AsyncSession,
):
fi_dao = FinancialInstitutionDao(
fi_dao_123, fi_dao_456 = FinancialInstitutionDao(
name="Test Bank 123",
lei="TESTBANK123",
domains=[
FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")
FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")
],
), FinancialInstitutionDao(
name="Test Bank 456",
lei="TESTBANK456",
domains=[
FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")
],
)
transaction_session.add(fi_dao)
transaction_session.add(fi_dao_123)
transaction_session.add(fi_dao_456)
await transaction_session.commit()

async def test_get_institutions(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session)
assert len(res) == 1
assert len(res) == 2

async def test_get_institutions_by_domain(self, query_session: AsyncSession):
res = await repo.get_institutions(query_session, domain="test.bank")
res = await repo.get_institutions(query_session, domain="test.bank.1")
assert len(res) == 1

async def test_get_institutions_by_domain_not_existing(
Expand All @@ -40,21 +47,33 @@ async def test_get_institutions_by_domain_not_existing(
res = await repo.get_institutions(query_session, domain="testing.bank")
assert len(res) == 0

async def test_get_institutions_by_lei_list(self, query_session: AsyncSession):
res = await repo.get_institutions(
query_session, leis=["TESTBANK123", "TESTBANK456"]
)
assert len(res) == 2

async def test_get_institutions_by_lei_list_item_not_existing(
self, query_session: AsyncSession
):
res = await repo.get_institutions(query_session, leis=["NOTTESTBANK"])
assert len(res) == 0

async def test_add_institution(self, transaction_session: AsyncSession):
await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(name="New Bank 123", lei="NEWBANK123"),
)
res = await repo.get_institutions(transaction_session)
assert len(res) == 2
assert len(res) == 3

async def test_update_institution(self, transaction_session: AsyncSession):
await repo.upsert_institution(
transaction_session,
FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"),
)
res = await repo.get_institutions(transaction_session)
assert len(res) == 1
assert len(res) == 2
assert res[0].name == "Test Bank 234"

async def test_add_domains(
Expand Down

0 comments on commit 29f1137

Please sign in to comment.