From 9f06e33275426fc458fd6a6499f6aef351159d65 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 20 May 2024 10:54:46 -0400 Subject: [PATCH] feat: replace lei checking dependencies with ones from api-commons (#165) closes #161 --- poetry.lock | 7 +- .../dependencies.py | 90 +------------- .../routers/institutions.py | 34 +++-- tests/app/test_dependencies.py | 117 ------------------ 4 files changed, 26 insertions(+), 222 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6da894c..841bb85 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiosqlite" @@ -1097,7 +1097,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -1106,8 +1105,6 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -1563,7 +1560,7 @@ uvicorn = "^0.29.0" type = "git" url = "https://github.com/cfpb/regtech-api-commons.git" reference = "HEAD" -resolved_reference = "b892c5036de47a9d88db53dcb9ce1f48ef5274b0" +resolved_reference = "50b600a7a4e87185d5503d782945d502bb0cc8b8" [[package]] name = "regtech-regex" diff --git a/src/regtech_user_fi_management/dependencies.py b/src/regtech_user_fi_management/dependencies.py index 512d03f..7f49649 100644 --- a/src/regtech_user_fi_management/dependencies.py +++ b/src/regtech_user_fi_management/dependencies.py @@ -1,19 +1,12 @@ -import functools - from http import HTTPStatus from typing import Annotated -from fastapi import Depends, Query, Request, Response -from fastapi.types import DecoratedCallable +from fastapi import Depends, Request from sqlalchemy.ext.asyncio import AsyncSession -from typing import List, Optional -from itertools import chain -from regtech_user_fi_management.config import settings from regtech_user_fi_management.entities.engine.engine import get_session import regtech_user_fi_management.entities.repos.institutions_repo as repo -from starlette.authentication import AuthCredentials -from regtech_api_commons.models.auth import AuthenticatedUser from regtech_api_commons.api.exceptions import RegTechHttpException +from regtech_api_commons.api.dependencies import get_email_domain async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: @@ -29,82 +22,3 @@ async def check_domain(request: Request, session: Annotated[AsyncSession, Depend async def email_domain_denied(session: AsyncSession, email: str) -> bool: return not await repo.is_domain_allowed(session, email) - - -def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: - """ - Parses leis from list of one or multiple strings to a list of - multiple distinct lei strings. - Returns empty list when nothing is passed in - Ex1: ['lei1,lei2'] -> ['lei1', 'lei2'] - Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4'] - """ - - if leis: - return list(chain.from_iterable([x.split(",") for x in leis])) - else: - return None - - -def get_email_domain(email: str) -> str: - if email: - return email.split("@")[-1] - return None - - -def is_admin(auth: AuthCredentials): - return settings.admin_scopes.issubset(auth.scopes) - - -def lei_association_check(func: DecoratedCallable) -> DecoratedCallable: - @functools.wraps(func) - async def wrapper(request: Request, *args, **kwargs) -> Response: - lei = kwargs.get("lei") - user: AuthenticatedUser = request.user - auth: AuthCredentials = request.auth - if not is_admin(auth) and lei not in user.institutions: - raise RegTechHttpException( - HTTPStatus.FORBIDDEN, name="Request Forbidden", detail=f"LEI {lei} is not associated with the user." - ) - return await func(request, *args, **kwargs) - - return wrapper # type: ignore[return-value] - - -def fi_search_association_check(func: DecoratedCallable) -> DecoratedCallable: - def verify_leis(user: AuthenticatedUser, leis: List[str]) -> None: - if not set(filter(len, leis)).issubset(set(filter(len, user.institutions))): - raise RegTechHttpException( - HTTPStatus.FORBIDDEN, - name="Request Forbidden", - detail=f"Institutions query with LEIs ({leis}) not associated with user is forbidden.", - ) - - def verify_domain(user: AuthenticatedUser, domain: str) -> None: - if domain != get_email_domain(user.email): - raise RegTechHttpException( - HTTPStatus.FORBIDDEN, - name="Request Forbidden", - detail=f"Institutions query with domain ({domain}) not associated with user is forbidden.", - ) - - @functools.wraps(func) - async def wrapper(request: Request, *args, **kwargs) -> Response: - user: AuthenticatedUser = request.user - auth: AuthCredentials = request.auth - if not is_admin(auth): - leis = kwargs.get("leis") - domain = kwargs.get("domain") - if leis: - verify_leis(user, leis) - elif domain: - verify_domain(user, domain) - elif not leis and not domain: - raise RegTechHttpException( - HTTPStatus.FORBIDDEN, - name="Request Forbidden", - detail="Retrieving institutions without filter is forbidden.", - ) - return await func(request=request, *args, **kwargs) - - return wrapper # type: ignore[return-value] diff --git a/src/regtech_user_fi_management/routers/institutions.py b/src/regtech_user_fi_management/routers/institutions.py index 515b94e..6fd4235 100644 --- a/src/regtech_user_fi_management/routers/institutions.py +++ b/src/regtech_user_fi_management/routers/institutions.py @@ -5,10 +5,6 @@ from regtech_api_commons.api.router_wrapper import Router from regtech_user_fi_management.dependencies import ( check_domain, - parse_leis, - get_email_domain, - lei_association_check, - fi_search_association_check, ) from typing import Annotated, List, Tuple, Literal from regtech_user_fi_management.entities.engine.engine import get_session @@ -30,6 +26,12 @@ from starlette.authentication import requires from regtech_api_commons.models.auth import AuthenticatedUser from regtech_api_commons.api.exceptions import RegTechHttpException +from regtech_api_commons.api.dependencies import ( + verify_institution_search, + verify_user_lei_relation, + parse_leis, + get_email_domain, +) oauth2_admin = OAuth2Admin(kc_settings) @@ -43,9 +45,10 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_ router = Router(dependencies=[Depends(set_db)]) -@router.get("/", response_model=List[FinancialInstitutionWithRelationsDto]) +@router.get( + "/", response_model=List[FinancialInstitutionWithRelationsDto], dependencies=[Depends(verify_institution_search)] +) @requires("authenticated") -@fi_search_association_check async def get_institutions( request: Request, leis: List[str] = Depends(parse_leis), @@ -104,9 +107,10 @@ async def get_federal_regulators(request: Request): return await repo.get_federal_regulators(request.state.db_session) -@router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto) +@router.get( + "/{lei}", response_model=FinancialInstitutionWithRelationsDto, dependencies=[Depends(verify_user_lei_relation)] +) @requires("authenticated") -@lei_association_check async def get_institution( request: Request, lei: str, @@ -117,9 +121,12 @@ async def get_institution( return res -@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) +@router.get( + "/{lei}/types/{type}", + response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None, + dependencies=[Depends(verify_user_lei_relation)], +) @requires("authenticated") -@lei_association_check async def get_types(request: Request, response: Response, lei: str, type: InstitutionType): match type: case "sbl": @@ -133,9 +140,12 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit ) -@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) +@router.put( + "/{lei}/types/{type}", + response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None, + dependencies=[Depends(verify_user_lei_relation)], +) @requires("authenticated") -@lei_association_check async def update_types( request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto ): diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py index 49ad2e1..451ff63 100644 --- a/tests/app/test_dependencies.py +++ b/tests/app/test_dependencies.py @@ -1,11 +1,5 @@ -from http import HTTPStatus -from typing import List -from fastapi import HTTPException, Request from pytest_mock import MockerFixture from sqlalchemy.ext.asyncio import AsyncSession -from regtech_user_fi_management.dependencies import lei_association_check, fi_search_association_check -from regtech_api_commons.api.exceptions import RegTechHttpException -from starlette.authentication import AuthCredentials import pytest @@ -35,114 +29,3 @@ async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession) assert await email_domain_denied(mock_session, allowed_domain) is False domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain) - - -async def test_lei_association_check_matching_lei(mock_request: Request): - @lei_association_check - async def method_to_wrap(request: Request, lei: str): - pass - - await method_to_wrap(mock_request, lei="TESTBANK123") - - -async def test_lei_association_check_is_admin(mock_request: Request): - mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) - - @lei_association_check - async def method_to_wrap(request: Request, lei: str): - pass - - await method_to_wrap(mock_request, lei="TESTBANK1234") - - -async def test_lei_association_check_not_matching(mock_request: Request): - @lei_association_check - async def method_to_wrap(request: Request, lei: str): - pass - - with pytest.raises(HTTPException) as e: - await method_to_wrap(mock_request, lei="NOTMYBANK") - assert e.value.status_code == HTTPStatus.FORBIDDEN - assert "not associated" in e.value.detail - assert isinstance(e.value, RegTechHttpException) - - -async def test_fi_search_association_check_matching_lei(mock_request: Request): - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - await method_to_wrap(mock_request, leis=["TESTBANK123"]) - - -async def test_fi_search_association_check_invalid_lei(mock_request: Request): - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - with pytest.raises(HTTPException) as e: - await method_to_wrap(mock_request, leis=["NOTMYBANK"]) - assert e.value.status_code == HTTPStatus.FORBIDDEN - assert "not associated" in e.value.detail - assert isinstance(e.value, RegTechHttpException) - - -async def test_fi_search_association_check_matching_domain(mock_request: Request): - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - await method_to_wrap(mock_request, domain="local.host") - - -async def test_fi_search_association_check_invalid_domain(mock_request: Request): - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - with pytest.raises(HTTPException) as e: - await method_to_wrap(mock_request, domain="not.myhost") - assert e.value.status_code == HTTPStatus.FORBIDDEN - assert "not associated" in e.value.detail - - -async def test_fi_search_association_check_no_filter(mock_request: Request): - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - with pytest.raises(HTTPException) as e: - await method_to_wrap(mock_request) - assert e.value.status_code == HTTPStatus.FORBIDDEN - assert "without filter" in e.value.detail - assert isinstance(e.value, RegTechHttpException) - - -async def test_fi_search_association_check_lei_admin(mock_request: Request): - mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) - - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - await method_to_wrap(mock_request, leis=["TESTBANK123", "ANOTHERBANK", "NOTMYBANK"]) - - -async def test_fi_search_association_check_domain_admin(mock_request: Request): - mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) - - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - await method_to_wrap(mock_request, domain="not.myhost") - - -async def test_fi_search_association_check_no_filter_admin(mock_request: Request): - mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) - - @fi_search_association_check - async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): - pass - - await method_to_wrap(mock_request)