Skip to content

Commit

Permalink
feat: replace lei checking dependencies with ones from api-commons (#165
Browse files Browse the repository at this point in the history
)

closes #161
  • Loading branch information
lchen-2101 authored May 20, 2024
1 parent 280b56b commit 9f06e33
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 222 deletions.
7 changes: 2 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 2 additions & 88 deletions src/regtech_user_fi_management/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
import functools

from http import HTTPStatus
from typing import Annotated
from fastapi import Depends, Query, Request, Response
from fastapi.types import DecoratedCallable
from fastapi import Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from itertools import chain
from regtech_user_fi_management.config import settings

from regtech_user_fi_management.entities.engine.engine import get_session
import regtech_user_fi_management.entities.repos.institutions_repo as repo
from starlette.authentication import AuthCredentials
from regtech_api_commons.models.auth import AuthenticatedUser
from regtech_api_commons.api.exceptions import RegTechHttpException
from regtech_api_commons.api.dependencies import get_email_domain


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
Expand All @@ -29,82 +22,3 @@ async def check_domain(request: Request, session: Annotated[AsyncSession, Depend

async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_domain_allowed(session, email)


def parse_leis(leis: List[str] = Query(None)) -> Optional[List]:
"""
Parses leis from list of one or multiple strings to a list of
multiple distinct lei strings.
Returns empty list when nothing is passed in
Ex1: ['lei1,lei2'] -> ['lei1', 'lei2']
Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4']
"""

if leis:
return list(chain.from_iterable([x.split(",") for x in leis]))
else:
return None


def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None


def is_admin(auth: AuthCredentials):
return settings.admin_scopes.issubset(auth.scopes)


def lei_association_check(func: DecoratedCallable) -> DecoratedCallable:
@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
lei = kwargs.get("lei")
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth) and lei not in user.institutions:
raise RegTechHttpException(
HTTPStatus.FORBIDDEN, name="Request Forbidden", detail=f"LEI {lei} is not associated with the user."
)
return await func(request, *args, **kwargs)

return wrapper # type: ignore[return-value]


def fi_search_association_check(func: DecoratedCallable) -> DecoratedCallable:
def verify_leis(user: AuthenticatedUser, leis: List[str]) -> None:
if not set(filter(len, leis)).issubset(set(filter(len, user.institutions))):
raise RegTechHttpException(
HTTPStatus.FORBIDDEN,
name="Request Forbidden",
detail=f"Institutions query with LEIs ({leis}) not associated with user is forbidden.",
)

def verify_domain(user: AuthenticatedUser, domain: str) -> None:
if domain != get_email_domain(user.email):
raise RegTechHttpException(
HTTPStatus.FORBIDDEN,
name="Request Forbidden",
detail=f"Institutions query with domain ({domain}) not associated with user is forbidden.",
)

@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth):
leis = kwargs.get("leis")
domain = kwargs.get("domain")
if leis:
verify_leis(user, leis)
elif domain:
verify_domain(user, domain)
elif not leis and not domain:
raise RegTechHttpException(
HTTPStatus.FORBIDDEN,
name="Request Forbidden",
detail="Retrieving institutions without filter is forbidden.",
)
return await func(request=request, *args, **kwargs)

return wrapper # type: ignore[return-value]
34 changes: 22 additions & 12 deletions src/regtech_user_fi_management/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
from regtech_api_commons.api.router_wrapper import Router
from regtech_user_fi_management.dependencies import (
check_domain,
parse_leis,
get_email_domain,
lei_association_check,
fi_search_association_check,
)
from typing import Annotated, List, Tuple, Literal
from regtech_user_fi_management.entities.engine.engine import get_session
Expand All @@ -30,6 +26,12 @@
from starlette.authentication import requires
from regtech_api_commons.models.auth import AuthenticatedUser
from regtech_api_commons.api.exceptions import RegTechHttpException
from regtech_api_commons.api.dependencies import (
verify_institution_search,
verify_user_lei_relation,
parse_leis,
get_email_domain,
)

oauth2_admin = OAuth2Admin(kc_settings)

Expand All @@ -43,9 +45,10 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_
router = Router(dependencies=[Depends(set_db)])


@router.get("/", response_model=List[FinancialInstitutionWithRelationsDto])
@router.get(
"/", response_model=List[FinancialInstitutionWithRelationsDto], dependencies=[Depends(verify_institution_search)]
)
@requires("authenticated")
@fi_search_association_check
async def get_institutions(
request: Request,
leis: List[str] = Depends(parse_leis),
Expand Down Expand Up @@ -104,9 +107,10 @@ async def get_federal_regulators(request: Request):
return await repo.get_federal_regulators(request.state.db_session)


@router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto)
@router.get(
"/{lei}", response_model=FinancialInstitutionWithRelationsDto, dependencies=[Depends(verify_user_lei_relation)]
)
@requires("authenticated")
@lei_association_check
async def get_institution(
request: Request,
lei: str,
Expand All @@ -117,9 +121,12 @@ async def get_institution(
return res


@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@router.get(
"/{lei}/types/{type}",
response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None,
dependencies=[Depends(verify_user_lei_relation)],
)
@requires("authenticated")
@lei_association_check
async def get_types(request: Request, response: Response, lei: str, type: InstitutionType):
match type:
case "sbl":
Expand All @@ -133,9 +140,12 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit
)


@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@router.put(
"/{lei}/types/{type}",
response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None,
dependencies=[Depends(verify_user_lei_relation)],
)
@requires("authenticated")
@lei_association_check
async def update_types(
request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto
):
Expand Down
117 changes: 0 additions & 117 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from http import HTTPStatus
from typing import List
from fastapi import HTTPException, Request
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession
from regtech_user_fi_management.dependencies import lei_association_check, fi_search_association_check
from regtech_api_commons.api.exceptions import RegTechHttpException
from starlette.authentication import AuthCredentials

import pytest

Expand Down Expand Up @@ -35,114 +29,3 @@ async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession)

assert await email_domain_denied(mock_session, allowed_domain) is False
domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain)


async def test_lei_association_check_matching_lei(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK123")


async def test_lei_association_check_is_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK1234")


async def test_lei_association_check_not_matching(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, lei="NOTMYBANK")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail
assert isinstance(e.value, RegTechHttpException)


async def test_fi_search_association_check_matching_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123"])


async def test_fi_search_association_check_invalid_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, leis=["NOTMYBANK"])
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail
assert isinstance(e.value, RegTechHttpException)


async def test_fi_search_association_check_matching_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="local.host")


async def test_fi_search_association_check_invalid_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, domain="not.myhost")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_no_filter(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request)
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "without filter" in e.value.detail
assert isinstance(e.value, RegTechHttpException)


async def test_fi_search_association_check_lei_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123", "ANOTHERBANK", "NOTMYBANK"])


async def test_fi_search_association_check_domain_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="not.myhost")


async def test_fi_search_association_check_no_filter_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request)

0 comments on commit 9f06e33

Please sign in to comment.