Skip to content

Commit

Permalink
Merge branch 'main' into feature/26_update_auth_user_model
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Sep 18, 2023
2 parents 1687cf1 + 342244c commit 61da079
Show file tree
Hide file tree
Showing 16 changed files with 128 additions and 199 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/linters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ name: Linters
on: [push]

jobs:
linting:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
32 changes: 17 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
description = ""
authors = ["lchen-2101 <[email protected]>"]
readme = "README.md"
packages = [{include = "regtech-user-fi-management"}]
packages = [{ include = "regtech-user-fi-management" }]

[tool.poetry.dependencies]
python = "^3.11"
Expand All @@ -30,22 +30,24 @@ pytest-mock = "^3.11.1"

[tool.pytest.ini_options]
asyncio_mode = "auto"
pythonpath = [
"src"
]
pythonpath = ["src"]
addopts = [
"--cov-report=term-missing",
"--cov-branch",
"--cov-report=xml",
"--cov-report=term",
"--cov=src",
"-vv",
"--strict-markers",
"-rfE",
]
testpaths = [
"tests",
"--cov-report=term-missing",
"--cov-branch",
"--cov-report=xml",
"--cov-report=term",
"--cov=src",
"-vv",
"--strict-markers",
"-rfE",
]
testpaths = ["tests"]

[tool.black]
line-length = 120

[tool.ruff]
line-length = 120

[tool.coverage.run]
relative_files = true
Expand Down
32 changes: 22 additions & 10 deletions src/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from itertools import chain

from fastapi import Query

from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand All @@ -13,25 +17,33 @@
}


async def check_domain(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
) -> None:
async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
if request_needs_domain_check(request):
if not request.user.is_authenticated:
raise HTTPException(status_code=HTTPStatus.FORBIDDEN)
if await email_domain_denied(session, request.user.email):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="email domain denied"
)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied")


def request_needs_domain_check(request: Request) -> bool:
path = request.scope["path"].rstrip("/")
return not (
path in OPEN_DOMAIN_REQUESTS
and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]
)
return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path])


async def email_domain_denied(session: AsyncSession, email: str) -> bool:
return not await repo.is_email_domain_allowed(session, email)


def parse_leis(leis: List[str] = Query(None)) -> Optional[List]:
"""
Parses leis from list of one or multiple strings to a list of
multiple distinct lei strings.
Returns empty list when nothing is passed in
Ex1: ['lei1,lei2'] -> ['lei1', 'lei2']
Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4']
"""

if leis:
return list(chain.from_iterable([x.split(",") for x in leis]))
else:
return None
8 changes: 2 additions & 6 deletions src/entities/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@

DB_URL = os.getenv("INST_CONN")
DB_SCHEMA = os.getenv("INST_DB_SCHEMA", "public")
engine = create_async_engine(DB_URL, echo=True).execution_options(
schema_translate_map={None: DB_SCHEMA}
)
SessionLocal = async_scoped_session(
async_sessionmaker(engine, expire_on_commit=False), current_task
)
engine = create_async_engine(DB_URL, echo=True).execution_options(schema_translate_map={None: DB_SCHEMA})
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)


async def get_session():
Expand Down
4 changes: 1 addition & 3 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ class FinancialInstitutionDao(AuditMixin, Base):
class FinancialInstitutionDomainDao(AuditMixin, Base):
__tablename__ = "financial_institution_domains"
domain: Mapped[str] = mapped_column(index=True, primary_key=True)
lei: Mapped[str] = mapped_column(
ForeignKey("financial_institutions.lei"), index=True, primary_key=True
)
lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True)
fi = relationship("FinancialInstitutionDao", back_populates="domains")


Expand Down
18 changes: 10 additions & 8 deletions src/entities/repos/institutions_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@


async def get_institutions(
session: AsyncSession, domain: str = "", page: int = 0, count: int = 100
session: AsyncSession,
leis: List[str] = None,
domain: str = "",
page: int = 0,
count: int = 100,
) -> List[FinancialInstitutionDao]:
async with session.begin():
stmt = (
Expand All @@ -23,7 +27,9 @@ async def get_institutions(
.limit(count)
.offset(page * count)
)
if d := domain.strip():
if leis:
stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis))
elif d := domain.strip():
search = "%{}%".format(d)
stmt = stmt.join(
FinancialInstitutionDomainDao,
Expand All @@ -43,13 +49,9 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti
return await session.scalar(stmt)


async def upsert_institution(
session: AsyncSession, fi: FinancialInstitutionDto
) -> FinancialInstitutionDao:
async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao:
async with session.begin():
stmt = select(FinancialInstitutionDao).filter(
FinancialInstitutionDao.lei == fi.lei
)
stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei)
res = await session.execute(stmt)
db_fi = res.scalar_one_or_none()
if db_fi is None:
Expand Down
20 changes: 5 additions & 15 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,23 @@


@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exception: HTTPException
) -> JSONResponse:
async def http_exception_handler(request: Request, exception: HTTPException) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=exception.status_code, content={"detail": exception.detail}
)
return JSONResponse(status_code=exception.status_code, content={"detail": exception.detail})


@app.exception_handler(Exception)
async def general_exception_handler(
request: Request, exception: Exception
) -> JSONResponse:
async def general_exception_handler(request: Request, exception: Exception) -> JSONResponse:
log.error(exception, exc_info=True, stack_info=True)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
content={"detail": "server error"},
)


oauth2_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL")
)
oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL"))

app.add_middleware(
AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme)
)
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme))
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
12 changes: 3 additions & 9 deletions src/oauth2/oauth2_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def update_user(self, user_id: str, payload: Dict[str, Any]) -> None:
self._admin.update_user(user_id, payload)
except kce.KeycloakError as e:
log.exception("Failed to update user: %s", user_id, extra=payload)
raise HTTPException(
status_code=e.response_code, detail="Failed to update user"
)
raise HTTPException(status_code=e.response_code, detail="Failed to update user")

def upsert_group(self, lei: str, name: str) -> str:
try:
Expand All @@ -65,9 +63,7 @@ def upsert_group(self, lei: str, name: str) -> str:
return group["id"]
except kce.KeycloakError as e:
log.exception("Failed to upsert group, lei: %s, name: %s", lei, name)
raise HTTPException(
status_code=e.response_code, detail="Failed to upsert group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to upsert group")

def get_group(self, lei: str) -> Dict[str, Any] | None:
try:
Expand All @@ -80,9 +76,7 @@ def associate_to_group(self, user_id: str, group_id: str) -> None:
self._admin.group_user_add(user_id, group_id)
except kce.KeycloakError as e:
log.exception("Failed to associate user %s to group %s", user_id, group_id)
raise HTTPException(
status_code=e.response_code, detail="Failed to associate user to group"
)
raise HTTPException(status_code=e.response_code, detail="Failed to associate user to group")

def associate_to_lei(self, user_id: str, lei: str) -> None:
group = self.get_group(lei)
Expand Down
8 changes: 2 additions & 6 deletions src/oauth2/oauth2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@ class BearerTokenAuthBackend(AuthenticationBackend):
def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None:
self.token_bearer = token_bearer

async def authenticate(
self, conn: HTTPConnection
) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
async def authenticate(self, conn: HTTPConnection) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]:
try:
token = await self.token_bearer(conn)
claims = oauth2_admin.get_claims(token)
if claims is not None:
auths = (
self.extract_nested(
claims, "resource_access", "realm-management", "roles"
)
self.extract_nested(claims, "resource_access", "realm-management", "roles")
+ self.extract_nested(claims, "resource_access", "account", "roles")
+ ["authenticated"]
)
Expand Down
8 changes: 4 additions & 4 deletions src/routers/institutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http import HTTPStatus
from oauth2 import oauth2_admin
from util import Router
from dependencies import parse_leis
from typing import Annotated, List, Tuple
from entities.engine import get_session
from entities.repos import institutions_repo as repo
Expand All @@ -15,9 +16,7 @@
from starlette.authentication import requires


async def set_db(
request: Request, session: Annotated[AsyncSession, Depends(get_session)]
):
async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]):
request.state.db_session = session


Expand All @@ -28,11 +27,12 @@ async def set_db(
@requires("authenticated")
async def get_institutions(
request: Request,
leis: List[str] = Depends(parse_leis),
domain: str = "",
page: int = 0,
count: int = 100,
):
return await repo.get_institutions(request.state.db_session, domain, page, count)
return await repo.get_institutions(request.state.db_session, leis, domain, page, count)


@router.post("/", response_model=Tuple[str, FinancialInstitutionDto])
Expand Down
4 changes: 1 addition & 3 deletions src/util/router_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def api_route(
if path.endswith("/"):
path = path[:-1]

add_path = super().api_route(
path, include_in_schema=include_in_schema, **kwargs
)
add_path = super().api_route(path, include_in_schema=include_in_schema, **kwargs)

add_alt_path = super().api_route(f"{path}/", include_in_schema=False, **kwargs)

Expand Down
4 changes: 1 addition & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def authed_user_mock(auth_mock: Mock) -> Mock:
"sub": "testuser123",
}
auth_mock.return_value = (
AuthCredentials(
["manage-account", "query-groups", "manage-users", "authenticated"]
),
AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]),
AuthenticatedUser.from_claim(claims),
)
return auth_mock
Expand Down
28 changes: 7 additions & 21 deletions tests/api/routers/test_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
res = client.get("/v1/admin/me")
assert res.status_code == 403

def test_get_me_authed_with_no_institutions(
self, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.get("/v1/admin/me")
assert res.status_code == 200
assert res.json().get("name") == "test"
assert res.json().get("institutions") == []

def test_get_me_authed_with_institutions(
self, app_fixture: FastAPI, auth_mock: Mock
):
def test_get_me_authed_with_institutions(self, app_fixture: FastAPI, auth_mock: Mock):
claims = {
"name": "test",
"preferred_username": "test_user",
Expand All @@ -44,9 +40,7 @@ def test_get_me_authed_with_institutions(

def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock):
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
Expand All @@ -61,14 +55,10 @@ def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock):
AuthenticatedUser.from_claim(claims),
)
client = TestClient(app_fixture)
res = client.put(
"/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}
)
res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"})
assert res.status_code == 403

def test_update_me(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
def test_update_me(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
update_user_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.update_user")
update_user_mock.return_value = None
client = TestClient(app_fixture)
Expand All @@ -77,12 +67,8 @@ def test_update_me(
update_user_mock.assert_called_once_with("testuser123", data)
assert res.status_code == 202

def test_associate_institutions(
self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock
):
associate_lei_mock = mocker.patch(
"oauth2.oauth2_admin.OAuth2Admin.associate_to_lei"
)
def test_associate_institutions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock):
associate_lei_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.associate_to_lei")
associate_lei_mock.return_value = None
client = TestClient(app_fixture)
data = ["testlei1", "testlei2"]
Expand Down
Loading

0 comments on commit 61da079

Please sign in to comment.