From 9b840aa25a6549a9e78a7d05c1b6459883cd0e43 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Thu, 14 Sep 2023 15:48:49 -0400 Subject: [PATCH] feat: refactor user model, and add institutions attribute --- src/entities/models/__init__.py | 2 ++ src/entities/models/dto.py | 35 ++++++++++++++++++++++++++++- src/oauth2/__init__.py | 4 ++-- src/oauth2/oauth2_backend.py | 25 ++------------------- src/routers/admin.py | 3 ++- tests/api/conftest.py | 2 +- tests/api/routers/test_admin_api.py | 26 ++++++++++++++++++--- 7 files changed, 66 insertions(+), 31 deletions(-) diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py index 6dfd58b..c772155 100644 --- a/src/entities/models/__init__.py +++ b/src/entities/models/__init__.py @@ -8,6 +8,7 @@ "FinancialInsitutionDomainCreate", "DeniedDomainDao", "DeniedDomainDto", + "AuthenticatedUser", ] from .dao import ( @@ -22,4 +23,5 @@ FinancialInsitutionDomainDto, FinancialInsitutionDomainCreate, DeniedDomainDto, + AuthenticatedUser, ) diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py index bc84375..d6c648a 100644 --- a/src/entities/models/dto.py +++ b/src/entities/models/dto.py @@ -1,5 +1,7 @@ -from typing import List +from typing import Any, Dict, List + from pydantic import BaseModel +from starlette.authentication import BaseUser class FinancialInsitutionDomainBase(BaseModel): @@ -37,3 +39,34 @@ class DeniedDomainDto(BaseModel): class Config: orm_mode = True + + +class AuthenticatedUser(BaseUser, BaseModel): + claims: Dict[str, Any] + name: str + username: str + email: str + id: str + institutions: List[str] + + @classmethod + def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser": + return cls( + claims=claims, + name=claims.get("name", ""), + username=claims.get("preferred_username", ""), + email=claims.get("email", ""), + id=claims.get("sub", ""), + institutions=cls.parse_institutions(claims.get("institutions")), + ) + + @classmethod + def parse_institutions(cls, institutions: List[str] | None) -> List[str]: + if institutions: + return list(map(lambda institution: institution.lstrip("/"), institutions)) + else: + return [] + + @property + def is_authenticated(self) -> bool: + return True diff --git a/src/oauth2/__init__.py b/src/oauth2/__init__.py index ec2dc7f..94759fc 100644 --- a/src/oauth2/__init__.py +++ b/src/oauth2/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["oauth2_admin", "BearerTokenAuthBackend", "AuthenticatedUser"] +__all__ = ["oauth2_admin", "BearerTokenAuthBackend"] from .oauth2_admin import oauth2_admin -from .oauth2_backend import BearerTokenAuthBackend, AuthenticatedUser +from .oauth2_backend import BearerTokenAuthBackend diff --git a/src/oauth2/oauth2_backend.py b/src/oauth2/oauth2_backend.py index 516f222..605f3f4 100644 --- a/src/oauth2/oauth2_backend.py +++ b/src/oauth2/oauth2_backend.py @@ -1,7 +1,6 @@ import logging from typing import Coroutine, Any, Dict, List, Tuple from fastapi import HTTPException -from pydantic import BaseModel from starlette.authentication import ( AuthCredentials, AuthenticationBackend, @@ -11,33 +10,13 @@ from fastapi.security import OAuth2AuthorizationCodeBearer from starlette.requests import HTTPConnection +from entities.models import AuthenticatedUser + from .oauth2_admin import oauth2_admin log = logging.getLogger(__name__) -class AuthenticatedUser(BaseUser, BaseModel): - claims: Dict[str, Any] - name: str | None - username: str | None - email: str | None - id: str | None - - @classmethod - def from_claim(cls, claims: Dict[str, Any]) -> "AuthenticatedUser": - return cls( - claims=claims, - name=claims.get("name"), - username=claims.get("preferred_username"), - email=claims.get("email"), - id=claims.get("sub"), - ) - - @property - def is_authenticated(self) -> bool: - return True - - class BearerTokenAuthBackend(AuthenticationBackend): def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None: self.token_bearer = token_bearer diff --git a/src/routers/admin.py b/src/routers/admin.py index 676dc1e..abf7467 100644 --- a/src/routers/admin.py +++ b/src/routers/admin.py @@ -4,7 +4,8 @@ from starlette.authentication import requires from util import Router -from oauth2 import AuthenticatedUser, oauth2_admin +from entities.models import AuthenticatedUser +from oauth2 import oauth2_admin router = Router() diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 148a35c..557f8b6 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -5,7 +5,7 @@ from pytest_mock import MockerFixture from starlette.authentication import AuthCredentials, UnauthenticatedUser -from oauth2.oauth2_backend import AuthenticatedUser +from entities.models import AuthenticatedUser @pytest.fixture diff --git a/tests/api/routers/test_admin_api.py b/tests/api/routers/test_admin_api.py index c973e1a..d1b5f21 100644 --- a/tests/api/routers/test_admin_api.py +++ b/tests/api/routers/test_admin_api.py @@ -5,7 +5,7 @@ from pytest_mock import MockerFixture from starlette.authentication import AuthCredentials -from oauth2.oauth2_backend import AuthenticatedUser +from entities.models import AuthenticatedUser class TestAdminApi: @@ -14,13 +14,33 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): res = client.get("/v1/admin/me") assert res.status_code == 403 - def test_get_me_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock + def test_get_me_authed_with_no_institutions( + self, app_fixture: FastAPI, authed_user_mock: Mock ): client = TestClient(app_fixture) res = client.get("/v1/admin/me") assert res.status_code == 200 assert res.json().get("name") == "test" + assert res.json().get("institutions") == [] + + def test_get_me_authed_with_institutions( + self, app_fixture: FastAPI, auth_mock: Mock + ): + claims = { + "name": "test", + "preferred_username": "test_user", + "email": "test@local.host", + "sub": "testuser123", + "institutions": ["/TEST1LEI", "/TEST2LEI"], + } + auth_mock.return_value = ( + AuthCredentials(["authenticated"]), + AuthenticatedUser.from_claim(claims), + ) + client = TestClient(app_fixture) + res = client.get("/v1/admin/me") + assert res.status_code == 200 + assert res.json().get("institutions") == ["TEST1LEI", "TEST2LEI"] def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture)