Skip to content

Commit

Permalink
feat: restrict FI data retrieval (#120)
Browse files Browse the repository at this point in the history
closes #111
  • Loading branch information
lchen-2101 authored Mar 18, 2024
1 parent 236c19f commit b469f4a
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/.env.local
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ INST_DB_HOST=localhost:5432
INST_DB_SCHEMA=public
JWT_OPTS_VERIFY_AT_HASH="false"
JWT_OPTS_VERIFY_AUD="false"
JWT_OPTS_VERIFY_ISS="false"
JWT_OPTS_VERIFY_ISS="false"
ADMIN_SCOPES=["query-groups","manage-users"]
3 changes: 2 additions & 1 deletion src/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from urllib import parse
from typing import Any
from typing import Any, Set

from pydantic import field_validator, ValidationInfo
from pydantic.networks import PostgresDsn
Expand All @@ -24,6 +24,7 @@ class Settings(BaseSettings):
inst_db_host: str
inst_db_scheme: str = "postgresql+asyncpg"
inst_conn: PostgresDsn | None = None
admin_scopes: Set[str] = set(["query-groups", "manage-users"])

def __init__(self, **data):
super().__init__(**data)
Expand Down
60 changes: 57 additions & 3 deletions src/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import functools

from http import HTTPStatus
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from fastapi import Depends, Query, HTTPException, Request, Response
from fastapi.types import DecoratedCallable
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from itertools import chain

from fastapi import Query
from config import settings

from entities.engine import get_session
from entities.repos import institutions_repo as repo
from starlette.authentication import AuthCredentials
from regtech_api_commons.models.auth import AuthenticatedUser


async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
Expand Down Expand Up @@ -41,3 +45,53 @@ def get_email_domain(email: str) -> str:
if email:
return email.split("@")[-1]
return None


def is_admin(auth: AuthCredentials):
return settings.admin_scopes.issubset(auth.scopes)


def lei_association_check(func: DecoratedCallable) -> DecoratedCallable:
@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
lei = kwargs.get("lei")
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth) and lei not in user.institutions:
raise HTTPException(HTTPStatus.FORBIDDEN, detail=f"LEI {lei} is not associated with the user.")
return await func(request, *args, **kwargs)

return wrapper # type: ignore[return-value]


def fi_search_association_check(func: DecoratedCallable) -> DecoratedCallable:
def verify_leis(user: AuthenticatedUser, leis: List[str]) -> None:
if not set(filter(len, leis)).issubset(set(filter(len, user.institutions))):
raise HTTPException(
HTTPStatus.FORBIDDEN,
detail=f"Institutions query with LEIs ({leis}) not associated with user is forbidden.",
)

def verify_domain(user: AuthenticatedUser, domain: str) -> None:
if domain != get_email_domain(user.email):
raise HTTPException(
HTTPStatus.FORBIDDEN,
detail=f"Institutions query with domain ({domain}) not associated with user is forbidden.",
)

@functools.wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Response:
user: AuthenticatedUser = request.user
auth: AuthCredentials = request.auth
if not is_admin(auth):
leis = kwargs.get("leis")
domain = kwargs.get("domain")
if leis:
verify_leis(user, leis)
elif domain:
verify_domain(user, domain)
elif not leis and not domain:
raise HTTPException(HTTPStatus.FORBIDDEN, detail="Retrieving institutions without filter is forbidden.")
return await func(request=request, *args, **kwargs)

return wrapper # type: ignore[return-value]
6 changes: 5 additions & 1 deletion src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from regtech_api_commons.oauth2.oauth2_admin import OAuth2Admin
from config import kc_settings
from regtech_api_commons.api import Router
from dependencies import check_domain, parse_leis, get_email_domain
from dependencies import check_domain, parse_leis, get_email_domain, lei_association_check, fi_search_association_check
from typing import Annotated, List, Tuple, Literal
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand Down Expand Up @@ -38,6 +38,7 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_

@router.get("/", response_model=List[FinancialInstitutionWithRelationsDto])
@requires("authenticated")
@fi_search_association_check
async def get_institutions(
request: Request,
leis: List[str] = Depends(parse_leis),
Expand Down Expand Up @@ -98,6 +99,7 @@ async def get_federal_regulators(request: Request):

@router.get("/{lei}", response_model=FinancialInstitutionWithRelationsDto)
@requires("authenticated")
@lei_association_check
async def get_institution(
request: Request,
lei: str,
Expand All @@ -110,6 +112,7 @@ async def get_institution(

@router.get("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
@lei_association_check
async def get_types(request: Request, response: Response, lei: str, type: InstitutionType):
match type:
case "sbl":
Expand All @@ -123,6 +126,7 @@ async def get_types(request: Request, response: Response, lei: str, type: Instit

@router.put("/{lei}/types/{type}", response_model=VersionedData[List[SblTypeAssociationDetailsDto]] | None)
@requires("authenticated")
@lei_association_check
async def update_types(
request: Request, response: Response, lei: str, type: InstitutionType, types_patch: SblTypeAssociationPatchDto
):
Expand Down
20 changes: 20 additions & 0 deletions tests/api/routers/test_institutions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,26 @@ def test_get_institutions_authed(
assert res.status_code == 200
assert res.json()[0].get("name") == "Test Bank 123"

def test_get_institutions_authed_not_admin(
self,
mocker: MockerFixture,
app_fixture: FastAPI,
auth_mock: Mock,
):
claims = {
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(["manage-account", "authenticated"]),
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.get("/v1/institutions/")
assert res.status_code == 403

def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"})
Expand Down
27 changes: 27 additions & 0 deletions tests/app/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Tuple
from fastapi import Request
import pytest

from pytest_mock import MockerFixture
from starlette.authentication import AuthCredentials
from regtech_api_commons.models import AuthenticatedUser, RegTechUser


@pytest.fixture(autouse=True)
Expand All @@ -10,3 +14,26 @@ def setup(mocker: MockerFixture):
mocked_engine.return_value = MockedEngine.return_value
mocker.patch("fastapi.security.OAuth2AuthorizationCodeBearer")
mocker.patch("entities.engine.get_session")


@pytest.fixture
def mock_auth() -> Tuple[AuthCredentials, RegTechUser]:
creds = AuthCredentials(["manage-account", "authenticated"])
user = AuthenticatedUser.from_claim(
{
"name": "test",
"preferred_username": "test_user",
"email": "[email protected]",
"sub": "testuser123",
"institutions": ["TESTBANK123"],
}
)
return creds, user


@pytest.fixture
def mock_request(mocker: MockerFixture, mock_auth: Tuple[AuthCredentials, RegTechUser]) -> Request:
request: Request = mocker.patch("fastapi.Request").return_value
request.auth = mock_auth[0]
request.user = mock_auth[1]
return request
113 changes: 113 additions & 0 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from http import HTTPStatus
from typing import List
from fastapi import HTTPException, Request
from pytest_mock import MockerFixture
from sqlalchemy.ext.asyncio import AsyncSession
from dependencies import lei_association_check, fi_search_association_check
from starlette.authentication import AuthCredentials

import pytest

Expand Down Expand Up @@ -29,3 +34,111 @@ async def test_domain_allowed(mocker: MockerFixture, mock_session: AsyncSession)

assert await email_domain_denied(mock_session, allowed_domain) is False
domain_allowed_mock.assert_called_once_with(mock_session, allowed_domain)


async def test_lei_association_check_matching_lei(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK123")


async def test_lei_association_check_is_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

await method_to_wrap(mock_request, lei="TESTBANK1234")


async def test_lei_association_check_not_matching(mock_request: Request):
@lei_association_check
async def method_to_wrap(request: Request, lei: str):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, lei="NOTMYBANK")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_matching_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123"])


async def test_fi_search_association_check_invalid_lei(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, leis=["NOTMYBANK"])
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_matching_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="local.host")


async def test_fi_search_association_check_invalid_domain(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request, domain="not.myhost")
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "not associated" in e.value.detail


async def test_fi_search_association_check_no_filter(mock_request: Request):
@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

with pytest.raises(HTTPException) as e:
await method_to_wrap(mock_request)
assert e.value.status_code == HTTPStatus.FORBIDDEN
assert "without filter" in e.value.detail


async def test_fi_search_association_check_lei_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, leis=["TESTBANK123", "ANOTHERBANK", "NOTMYBANK"])


async def test_fi_search_association_check_domain_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request, domain="not.myhost")


async def test_fi_search_association_check_no_filter_admin(mock_request: Request):
mock_request.auth = AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"])

@fi_search_association_check
async def method_to_wrap(request: Request, leis: List[str] = [], domain: str = ""):
pass

await method_to_wrap(mock_request)

0 comments on commit b469f4a

Please sign in to comment.