diff --git a/src/.env.local b/src/.env.local index 69a344b..af5dd0e 100644 --- a/src/.env.local +++ b/src/.env.local @@ -15,4 +15,5 @@ INST_DB_HOST=localhost:5432 INST_DB_SCHEMA=public JWT_OPTS_VERIFY_AT_HASH="false" JWT_OPTS_VERIFY_AUD="false" -JWT_OPTS_VERIFY_ISS="false" \ No newline at end of file +JWT_OPTS_VERIFY_ISS="false" +ADMIN_SCOPES=["query-groups","manage-users"] \ No newline at end of file diff --git a/src/config.py b/src/config.py index 6b8cae6..817ac53 100644 --- a/src/config.py +++ b/src/config.py @@ -1,6 +1,6 @@ import os from urllib import parse -from typing import Any +from typing import Any, Set from pydantic import field_validator, ValidationInfo from pydantic.networks import PostgresDsn @@ -24,6 +24,7 @@ class Settings(BaseSettings): inst_db_host: str inst_db_scheme: str = "postgresql+asyncpg" inst_conn: PostgresDsn | None = None + admin_scopes: Set[str] = set(["query-groups", "manage-users"]) def __init__(self, **data): super().__init__(**data) diff --git a/src/dependencies.py b/src/dependencies.py index 90a3c94..3516557 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -1,14 +1,18 @@ +import functools + from http import HTTPStatus from typing import Annotated -from fastapi import Depends, HTTPException, Request +from fastapi import Depends, Query, HTTPException, Request, Response +from fastapi.types import DecoratedCallable from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional from itertools import chain - -from fastapi import Query +from config import settings from entities.engine import get_session from entities.repos import institutions_repo as repo +from starlette.authentication import AuthCredentials +from regtech_api_commons.models.auth import AuthenticatedUser async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: @@ -41,3 +45,53 @@ def get_email_domain(email: str) -> str: if email: return email.split("@")[-1] return None + + +def is_admin(auth: AuthCredentials): + return settings.admin_scopes.issubset(auth.scopes) + + +def lei_association_check(func: DecoratedCallable) -> DecoratedCallable: + @functools.wraps(func) + async def wrapper(request: Request, *args, **kwargs) -> Response: + lei = kwargs.get("lei") + user: AuthenticatedUser = request.user + auth: AuthCredentials = request.auth + if not is_admin(auth) and lei not in user.institutions: + raise HTTPException(HTTPStatus.FORBIDDEN, detail=f"LEI {lei} is not associated with the user.") + return await func(request, *args, **kwargs) + + return wrapper # type: ignore[return-value] + + +def fi_search_association_check(func: DecoratedCallable) -> DecoratedCallable: + def verify_leis(user: AuthenticatedUser, leis: List[str]) -> None: + if not set(filter(len, leis)).issubset(set(filter(len, user.institutions))): + raise HTTPException( + HTTPStatus.FORBIDDEN, + detail=f"Institutions query with LEIs ({leis}) not associated with user is forbidden.", + ) + + def verify_domain(user: AuthenticatedUser, domain: str) -> None: + if domain != get_email_domain(user.email): + raise HTTPException( + HTTPStatus.FORBIDDEN, + detail=f"Institutions query with domain ({domain}) not associated with user is forbidden.", + ) + + @functools.wraps(func) + async def wrapper(request: Request, *args, **kwargs) -> Response: + user: AuthenticatedUser = request.user + auth: AuthCredentials = request.auth + if not is_admin(auth): + leis = kwargs.get("leis") + domain = kwargs.get("domain") + if leis: + verify_leis(user, leis) + elif domain: + verify_domain(user, domain) + elif not leis and not domain: + raise HTTPException(HTTPStatus.FORBIDDEN, detail="Retrieving institutions without filter is forbidden.") + return await func(request=request, *args, **kwargs) + + return wrapper # type: ignore[return-value] diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 347838a..c9f339e 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -3,7 +3,7 @@ from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin from config import kc_settings from regtech_api_commons.api import Router -from dependencies import check_domain, parse_leis, get_email_domain +from dependencies import check_domain, parse_leis, get_email_domain, lei_association_check, fi_search_association_check from typing import Annotated, List, Tuple, Literal from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -38,6 +38,7 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_ @router.get("/", response_model=List[FinancialInstitutionWithRelationsDto]) @requires("authenticated") +@fi_search_association_check async def get_institutions( request: Request, leis: List[str] = Depends(parse_leis), @@ -98,6 +99,7 @@ async def get_federal_regulators(request: Request): @router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto) @requires("authenticated") +@lei_association_check async def get_institution( request: Request, lei: str, @@ -110,6 +112,7 @@ async def get_institution( @router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) @requires("authenticated") +@lei_association_check async def get_types(request: Request, response: Response, lei: str, type: InstitutionType): match type: case "sbl": @@ -123,6 +126,7 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit @router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None) @requires("authenticated") +@lei_association_check async def update_types( request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto ): diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index 50cb95e..8220607 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -32,6 +32,26 @@ def test_get_institutions_authed( assert res.status_code == 200 assert res.json()[0].get("name") == "Test Bank 123" + def test_get_institutions_authed_not_admin( + self, + mocker: MockerFixture, + app_fixture: FastAPI, + auth_mock: Mock, + ): + claims = { + "name": "test", + "preferred_username": "test_user", + "email": "test@local.host", + "sub": "testuser123", + } + auth_mock.return_value = ( + AuthCredentials(["manage-account", "authenticated"]), + AuthenticatedUser.from_claim(claims), + ) + client = TestClient(app_fixture) + res = client.get("/v1/institutions/") + assert res.status_code == 403 + def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 18cc99d..1a72ea3 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -1,6 +1,10 @@ +from typing import Tuple +from fastapi import Request import pytest from pytest_mock import MockerFixture +from starlette.authentication import AuthCredentials +from regtech_api_commons.models import AuthenticatedUser, RegTechUser @pytest.fixture(autouse=True) @@ -10,3 +14,26 @@ def setup(mocker: MockerFixture): mocked_engine.return_value = MockedEngine.return_value mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer") mocker.patch("entities.engine.get_session") + + +@pytest.fixture +def mock_auth() -> Tuple[AuthCredentials, RegTechUser]: + creds = AuthCredentials(["manage-account", "authenticated"]) + user = AuthenticatedUser.from_claim( + { + "name": "test", + "preferred_username": "test_user", + "email": "test@local.host", + "sub": "testuser123", + "institutions": ["TESTBANK123"], + } + ) + return creds, user + + +@pytest.fixture +def mock_request(mocker: MockerFixture, mock_auth: Tuple[AuthCredentials, RegTechUser]) -> Request: + request: Request = mocker.patch("fastapi.Request").return_value + request.auth = mock_auth[0] + request.user = mock_auth[1] + return request diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py index f44692a..df47aa3 100644 --- a/tests/app/test_dependencies.py +++ b/tests/app/test_dependencies.py @@ -1,5 +1,10 @@ +from http import HTTPStatus +from typing import List +from fastapi import HTTPException, Request from pytest_mock import MockerFixture from sqlalchemy.ext.asyncio import AsyncSession +from dependencies import lei_association_check, fi_search_association_check +from starlette.authentication import AuthCredentials import pytest @@ -29,3 +34,111 @@ async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession) assert await email_domain_denied(mock_session, allowed_domain) is False domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain) + + +async def test_lei_association_check_matching_lei(mock_request: Request): + @lei_association_check + async def method_to_wrap(request: Request, lei: str): + pass + + await method_to_wrap(mock_request, lei="TESTBANK123") + + +async def test_lei_association_check_is_admin(mock_request: Request): + mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) + + @lei_association_check + async def method_to_wrap(request: Request, lei: str): + pass + + await method_to_wrap(mock_request, lei="TESTBANK1234") + + +async def test_lei_association_check_not_matching(mock_request: Request): + @lei_association_check + async def method_to_wrap(request: Request, lei: str): + pass + + with pytest.raises(HTTPException) as e: + await method_to_wrap(mock_request, lei="NOTMYBANK") + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert "not associated" in e.value.detail + + +async def test_fi_search_association_check_matching_lei(mock_request: Request): + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + await method_to_wrap(mock_request, leis=["TESTBANK123"]) + + +async def test_fi_search_association_check_invalid_lei(mock_request: Request): + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + with pytest.raises(HTTPException) as e: + await method_to_wrap(mock_request, leis=["NOTMYBANK"]) + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert "not associated" in e.value.detail + + +async def test_fi_search_association_check_matching_domain(mock_request: Request): + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + await method_to_wrap(mock_request, domain="local.host") + + +async def test_fi_search_association_check_invalid_domain(mock_request: Request): + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + with pytest.raises(HTTPException) as e: + await method_to_wrap(mock_request, domain="not.myhost") + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert "not associated" in e.value.detail + + +async def test_fi_search_association_check_no_filter(mock_request: Request): + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + with pytest.raises(HTTPException) as e: + await method_to_wrap(mock_request) + assert e.value.status_code == HTTPStatus.FORBIDDEN + assert "without filter" in e.value.detail + + +async def test_fi_search_association_check_lei_admin(mock_request: Request): + mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) + + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + await method_to_wrap(mock_request, leis=["TESTBANK123", "ANOTHERBANK", "NOTMYBANK"]) + + +async def test_fi_search_association_check_domain_admin(mock_request: Request): + mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) + + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + await method_to_wrap(mock_request, domain="not.myhost") + + +async def test_fi_search_association_check_no_filter_admin(mock_request: Request): + mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]) + + @fi_search_association_check + async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""): + pass + + await method_to_wrap(mock_request)