From 29f1137c782b21ff0a1288a49a8263cf50190eb6 Mon Sep 17 00:00:00 2001 From: Justin McGuffee Date: Fri, 15 Sep 2023 09:55:44 -0500 Subject: [PATCH 1/2] Added endpoint to return lei data from list of leis (#20) Added endpoint to return lei data from request containing list of leis. Added repo to gather data from db. Added simple error handling for when none of leis from list are found. --------- Co-authored-by: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> --- src/dependencies.py | 19 +++++++++++ src/entities/repos/institutions_repo.py | 10 ++++-- src/routers/institutions.py | 6 +++- .../entities/repos/test_institutions_repo.py | 33 +++++++++++++++---- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/dependencies.py b/src/dependencies.py index 384c7c9..e9e178d 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -2,6 +2,10 @@ from typing import Annotated from fastapi import Depends, HTTPException, Request from sqlalchemy.ext.asyncio import AsyncSession +from typing import List, Optional +from itertools import chain + +from fastapi import Query from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -35,3 +39,18 @@ def request_needs_domain_check(request: Request) -> bool: async def email_domain_denied(session: AsyncSession, email: str) -> bool: return not await repo.is_email_domain_allowed(session, email) + + +def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: + """ + Parses leis from list of one or multiple strings to a list of + multiple distinct lei strings. + Returns empty list when nothing is passed in + Ex1: ['lei1,lei2'] -> ['lei1', 'lei2'] + Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4'] + """ + + if leis: + return list(chain.from_iterable([x.split(",") for x in leis])) + else: + return None diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 887809b..fb7ef88 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -14,7 +14,11 @@ async def get_institutions( - session: AsyncSession, domain: str = "", page: int = 0, count: int = 100 + session: AsyncSession, + leis: List[str] = None, + domain: str = "", + page: int = 0, + count: int = 100, ) -> List[FinancialInstitutionDao]: async with session.begin(): stmt = ( @@ -23,7 +27,9 @@ async def get_institutions( .limit(count) .offset(page * count) ) - if d := domain.strip(): + if leis: + stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) + elif d := domain.strip(): search = "%{}%".format(d) stmt = stmt.join( FinancialInstitutionDomainDao, diff --git a/src/routers/institutions.py b/src/routers/institutions.py index e54ed32..440ca13 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,6 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router +from dependencies import parse_leis from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -28,11 +29,14 @@ async def set_db( @requires("authenticated") async def get_institutions( request: Request, + leis: List[str] = Depends(parse_leis), domain: str = "", page: int = 0, count: int = 100, ): - return await repo.get_institutions(request.state.db_session, domain, page, count) + return await repo.get_institutions( + request.state.db_session, leis, domain, page, count + ) @router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index f042c7f..584fd31 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -16,22 +16,29 @@ async def setup( self, transaction_session: AsyncSession, ): - fi_dao = FinancialInstitutionDao( + fi_dao_123, fi_dao_456 = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") + FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123") + ], + ), FinancialInstitutionDao( + name="Test Bank 456", + lei="TESTBANK456", + domains=[ + FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456") ], ) - transaction_session.add(fi_dao) + transaction_session.add(fi_dao_123) + transaction_session.add(fi_dao_456) await transaction_session.commit() async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) - assert len(res) == 1 + assert len(res) == 2 async def test_get_institutions_by_domain(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, domain="test.bank") + res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 async def test_get_institutions_by_domain_not_existing( @@ -40,13 +47,25 @@ async def test_get_institutions_by_domain_not_existing( res = await repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 + async def test_get_institutions_by_lei_list(self, query_session: AsyncSession): + res = await repo.get_institutions( + query_session, leis=["TESTBANK123", "TESTBANK456"] + ) + assert len(res) == 2 + + async def test_get_institutions_by_lei_list_item_not_existing( + self, query_session: AsyncSession + ): + res = await repo.get_institutions(query_session, leis=["NOTTESTBANK"]) + assert len(res) == 0 + async def test_add_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( transaction_session, FinancialInstitutionDao(name="New Bank 123", lei="NEWBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 2 + assert len(res) == 3 async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( @@ -54,7 +73,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 1 + assert len(res) == 2 assert res[0].name == "Test Bank 234" async def test_add_domains( From 342244c16bd2fdde594b6d31a394636d64609aae Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:49:10 -0700 Subject: [PATCH 2/2] Feature/30 gh actions update (#32) closes #30 all the python file changes are from black reformatting based on new line-length updates. --- .github/workflows/linters.yml | 6 +- pyproject.toml | 32 ++++--- src/dependencies.py | 13 +-- src/entities/engine/engine.py | 8 +- src/entities/models/dao.py | 4 +- src/entities/repos/institutions_repo.py | 8 +- src/main.py | 20 +--- src/oauth2/oauth2_admin.py | 12 +-- src/oauth2/oauth2_backend.py | 8 +- src/routers/institutions.py | 8 +- src/util/router_wrapper.py | 4 +- tests/api/conftest.py | 4 +- tests/api/routers/test_admin_api.py | 24 ++--- tests/api/routers/test_institutions_api.py | 92 +++++-------------- tests/entities/conftest.py | 4 +- .../entities/repos/test_institutions_repo.py | 34 ++----- 16 files changed, 82 insertions(+), 199 deletions(-) diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 3387d0b..2f24298 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -3,9 +3,13 @@ name: Linters on: [push] jobs: - linting: + black: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 - uses: chartboost/ruff-action@v1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fc88790..d2412b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.0" description = "" authors = ["lchen-2101 <73617864+lchen-2101@users.noreply.github.com>"] readme = "README.md" -packages = [{include = "regtech-user-fi-management"}] +packages = [{ include = "regtech-user-fi-management" }] [tool.poetry.dependencies] python = "^3.11" @@ -30,22 +30,24 @@ pytest-mock = "^3.11.1" [tool.pytest.ini_options] asyncio_mode = "auto" -pythonpath = [ - "src" -] +pythonpath = ["src"] addopts = [ - "--cov-report=term-missing", - "--cov-branch", - "--cov-report=xml", - "--cov-report=term", - "--cov=src", - "-vv", - "--strict-markers", - "-rfE", -] -testpaths = [ - "tests", + "--cov-report=term-missing", + "--cov-branch", + "--cov-report=xml", + "--cov-report=term", + "--cov=src", + "-vv", + "--strict-markers", + "-rfE", ] +testpaths = ["tests"] + +[tool.black] +line-length = 120 + +[tool.ruff] +line-length = 120 [tool.coverage.run] relative_files = true diff --git a/src/dependencies.py b/src/dependencies.py index e9e178d..02ddcb0 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -17,24 +17,17 @@ } -async def check_domain( - request: Request, session: Annotated[AsyncSession, Depends(get_session)] -) -> None: +async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: if request_needs_domain_check(request): if not request.user.is_authenticated: raise HTTPException(status_code=HTTPStatus.FORBIDDEN) if await email_domain_denied(session, request.user.email): - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, detail="email domain denied" - ) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied") def request_needs_domain_check(request: Request) -> bool: path = request.scope["path"].rstrip("/") - return not ( - path in OPEN_DOMAIN_REQUESTS - and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path] - ) + return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]) async def email_domain_denied(session: AsyncSession, email: str) -> bool: diff --git a/src/entities/engine/engine.py b/src/entities/engine/engine.py index f1b0082..d37c1b1 100644 --- a/src/entities/engine/engine.py +++ b/src/entities/engine/engine.py @@ -8,12 +8,8 @@ DB_URL = os.getenv("INST_CONN") DB_SCHEMA = os.getenv("INST_DB_SCHEMA", "public") -engine = create_async_engine(DB_URL, echo=True).execution_options( - schema_translate_map={None: DB_SCHEMA} -) -SessionLocal = async_scoped_session( - async_sessionmaker(engine, expire_on_commit=False), current_task -) +engine = create_async_engine(DB_URL, echo=True).execution_options(schema_translate_map={None: DB_SCHEMA}) +SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) async def get_session(): diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 3059bd1..3213524 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -23,9 +23,7 @@ class FinancialInstitutionDao(AuditMixin, Base): class FinancialInstitutionDomainDao(AuditMixin, Base): __tablename__ = "financial_institution_domains" domain: Mapped[str] = mapped_column(index=True, primary_key=True) - lei: Mapped[str] = mapped_column( - ForeignKey("financial_institutions.lei"), index=True, primary_key=True - ) + lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True) fi = relationship("FinancialInstitutionDao", back_populates="domains") diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index fb7ef88..ba33cf1 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -49,13 +49,9 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti return await session.scalar(stmt) -async def upsert_institution( - session: AsyncSession, fi: FinancialInstitutionDto -) -> FinancialInstitutionDao: +async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao: async with session.begin(): - stmt = select(FinancialInstitutionDao).filter( - FinancialInstitutionDao.lei == fi.lei - ) + stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei) res = await session.execute(stmt) db_fi = res.scalar_one_or_none() if db_fi is None: diff --git a/src/main.py b/src/main.py index 3fd265a..bc54b7a 100644 --- a/src/main.py +++ b/src/main.py @@ -19,19 +19,13 @@ @app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exception: HTTPException -) -> JSONResponse: +async def http_exception_handler(request: Request, exception: HTTPException) -> JSONResponse: log.error(exception, exc_info=True, stack_info=True) - return JSONResponse( - status_code=exception.status_code, content={"detail": exception.detail} - ) + return JSONResponse(status_code=exception.status_code, content={"detail": exception.detail}) @app.exception_handler(Exception) -async def general_exception_handler( - request: Request, exception: Exception -) -> JSONResponse: +async def general_exception_handler(request: Request, exception: Exception) -> JSONResponse: log.error(exception, exc_info=True, stack_info=True) return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -39,13 +33,9 @@ async def general_exception_handler( ) -oauth2_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL") -) +oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL")) -app.add_middleware( - AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme) -) +app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme)) app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/src/oauth2/oauth2_admin.py b/src/oauth2/oauth2_admin.py index 0b473d1..94fb16e 100644 --- a/src/oauth2/oauth2_admin.py +++ b/src/oauth2/oauth2_admin.py @@ -50,9 +50,7 @@ def update_user(self, user_id: str, payload: Dict[str, Any]) -> None: self._admin.update_user(user_id, payload) except kce.KeycloakError as e: log.exception("Failed to update user: %s", user_id, extra=payload) - raise HTTPException( - status_code=e.response_code, detail="Failed to update user" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to update user") def upsert_group(self, lei: str, name: str) -> str: try: @@ -65,9 +63,7 @@ def upsert_group(self, lei: str, name: str) -> str: return group["id"] except kce.KeycloakError as e: log.exception("Failed to upsert group, lei: %s, name: %s", lei, name) - raise HTTPException( - status_code=e.response_code, detail="Failed to upsert group" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to upsert group") def get_group(self, lei: str) -> Dict[str, Any] | None: try: @@ -80,9 +76,7 @@ def associate_to_group(self, user_id: str, group_id: str) -> None: self._admin.group_user_add(user_id, group_id) except kce.KeycloakError as e: log.exception("Failed to associate user %s to group %s", user_id, group_id) - raise HTTPException( - status_code=e.response_code, detail="Failed to associate user to group" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to associate user to group") def associate_to_lei(self, user_id: str, lei: str) -> None: group = self.get_group(lei) diff --git a/src/oauth2/oauth2_backend.py b/src/oauth2/oauth2_backend.py index 516f222..22c682f 100644 --- a/src/oauth2/oauth2_backend.py +++ b/src/oauth2/oauth2_backend.py @@ -42,17 +42,13 @@ class BearerTokenAuthBackend(AuthenticationBackend): def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None: self.token_bearer = token_bearer - async def authenticate( - self, conn: HTTPConnection - ) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]: + async def authenticate(self, conn: HTTPConnection) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]: try: token = await self.token_bearer(conn) claims = oauth2_admin.get_claims(token) if claims is not None: auths = ( - self.extract_nested( - claims, "resource_access", "realm-management", "roles" - ) + self.extract_nested(claims, "resource_access", "realm-management", "roles") + self.extract_nested(claims, "resource_access", "account", "roles") + ["authenticated"] ) diff --git a/src/routers/institutions.py b/src/routers/institutions.py index 440ca13..1de7b94 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -16,9 +16,7 @@ from starlette.authentication import requires -async def set_db( - request: Request, session: Annotated[AsyncSession, Depends(get_session)] -): +async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]): request.state.db_session = session @@ -34,9 +32,7 @@ async def get_institutions( page: int = 0, count: int = 100, ): - return await repo.get_institutions( - request.state.db_session, leis, domain, page, count - ) + return await repo.get_institutions(request.state.db_session, leis, domain, page, count) @router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) diff --git a/src/util/router_wrapper.py b/src/util/router_wrapper.py index 8efd137..b3cc6f4 100644 --- a/src/util/router_wrapper.py +++ b/src/util/router_wrapper.py @@ -11,9 +11,7 @@ def api_route( if path.endswith("/"): path = path[:-1] - add_path = super().api_route( - path, include_in_schema=include_in_schema, **kwargs - ) + add_path = super().api_route(path, include_in_schema=include_in_schema, **kwargs) add_alt_path = super().api_route(f"{path}/", include_in_schema=False, **kwargs) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 148a35c..54ef356 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -35,9 +35,7 @@ def authed_user_mock(auth_mock: Mock) -> Mock: "sub": "testuser123", } auth_mock.return_value = ( - AuthCredentials( - ["manage-account", "query-groups", "manage-users", "authenticated"] - ), + AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]), AuthenticatedUser.from_claim(claims), ) return auth_mock diff --git a/tests/api/routers/test_admin_api.py b/tests/api/routers/test_admin_api.py index c973e1a..c2b1ee2 100644 --- a/tests/api/routers/test_admin_api.py +++ b/tests/api/routers/test_admin_api.py @@ -14,9 +14,7 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): res = client.get("/v1/admin/me") assert res.status_code == 403 - def test_get_me_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): client = TestClient(app_fixture) res = client.get("/v1/admin/me") assert res.status_code == 200 @@ -24,9 +22,7 @@ def test_get_me_authed( def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) - res = client.put( - "/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"} - ) + res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}) assert res.status_code == 403 def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): @@ -41,14 +37,10 @@ def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - res = client.put( - "/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"} - ) + res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}) assert res.status_code == 403 - def test_update_me( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_update_me(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): update_user_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.update_user") update_user_mock.return_value = None client = TestClient(app_fixture) @@ -57,12 +49,8 @@ def test_update_me( update_user_mock.assert_called_once_with("testuser123", data) assert res.status_code == 202 - def test_associate_institutions( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - associate_lei_mock = mocker.patch( - "oauth2.oauth2_admin.OAuth2Admin.associate_to_lei" - ) + def test_associate_institutions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + associate_lei_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.associate_to_lei") associate_lei_mock.return_value = None client = TestClient(app_fixture) data = ["testlei1", "testlei2"] diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index a666e44..7130437 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -9,26 +9,18 @@ class TestInstitutionsApi: - def test_get_institutions_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) res = client.get("/v1/institutions/") assert res.status_code == 403 - def test_get_institutions_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - get_institutions_mock = mocker.patch( - "entities.repos.institutions_repo.get_institutions" - ) + def test_get_institutions_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institutions_mock = mocker.patch("entities.repos.institutions_repo.get_institutions") get_institutions_mock.return_value = [ FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) ] client = TestClient(app_fixture) @@ -36,40 +28,26 @@ def test_get_institutions_authed( assert res.status_code == 200 assert res.json()[0].get("name") == "Test Bank 123" - def test_create_institution_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 403 - def test_create_institution_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - upsert_institution_mock = mocker.patch( - "entities.repos.institutions_repo.upsert_institution" - ) + def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + upsert_institution_mock = mocker.patch("entities.repos.institutions_repo.upsert_institution") upsert_institution_mock.return_value = FinancialInstitutionDao( name="testName", lei="testLei", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) upsert_group_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.upsert_group") upsert_group_mock.return_value = "leiGroup" client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 200 assert res.json()[1].get("name") == "testName" - def test_create_institution_authed_no_permission( - self, app_fixture: FastAPI, auth_mock: Mock - ): + def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", "preferred_username": "test_user", @@ -81,31 +59,21 @@ def test_create_institution_authed_no_permission( AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 403 - def test_get_institution_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_get_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) lei_path = "testLeiPath" res = client.get(f"/v1/institutions/{lei_path}") assert res.status_code == 403 - def test_get_institution_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - get_institution_mock = mocker.patch( - "entities.repos.institutions_repo.get_institution" - ) + def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") get_institution_mock.return_value = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) client = TestClient(app_fixture) lei_path = "testLeiPath" @@ -117,30 +85,20 @@ def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mo client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 - def test_add_domains_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_add_domains_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): add_domains_mock = mocker.patch("entities.repos.institutions_repo.add_domains") - add_domains_mock.return_value = [ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ] + add_domains_mock.return_value = [FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")] client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 200 assert res.json()[0].get("domain") == "test.bank" - def test_add_domains_authed_no_permission( - self, app_fixture: FastAPI, auth_mock: Mock - ): + def test_add_domains_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", "preferred_username": "test_user", @@ -153,9 +111,7 @@ def test_add_domains_authed_no_permission( ) client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 def test_add_domains_authed_with_denied_email_domain( @@ -165,8 +121,6 @@ def test_add_domains_authed_with_denied_email_domain( domain_denied_mock.return_value = True client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 assert "domain denied" in res.json()["detail"] diff --git a/tests/entities/conftest.py b/tests/entities/conftest.py index 990d971..fc0633d 100644 --- a/tests/entities/conftest.py +++ b/tests/entities/conftest.py @@ -58,6 +58,4 @@ async def query_session(session_generator: async_scoped_session): @pytest.fixture(scope="function") def session_generator(engine: AsyncEngine): - return async_scoped_session( - async_sessionmaker(engine, expire_on_commit=False), current_task - ) + return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index 584fd31..5acdfd6 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -19,15 +19,11 @@ async def setup( fi_dao_123, fi_dao_456 = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")], ), FinancialInstitutionDao( name="Test Bank 456", lei="TESTBANK456", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")], ) transaction_session.add(fi_dao_123) transaction_session.add(fi_dao_456) @@ -41,21 +37,15 @@ async def test_get_institutions_by_domain(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 - async def test_get_institutions_by_domain_not_existing( - self, query_session: AsyncSession - ): + async def test_get_institutions_by_domain_not_existing(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 async def test_get_institutions_by_lei_list(self, query_session: AsyncSession): - res = await repo.get_institutions( - query_session, leis=["TESTBANK123", "TESTBANK456"] - ) + res = await repo.get_institutions(query_session, leis=["TESTBANK123", "TESTBANK456"]) assert len(res) == 2 - async def test_get_institutions_by_lei_list_item_not_existing( - self, query_session: AsyncSession - ): + async def test_get_institutions_by_lei_list_item_not_existing(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, leis=["NOTTESTBANK"]) assert len(res) == 0 @@ -76,9 +66,7 @@ async def test_update_institution(self, transaction_session: AsyncSession): assert len(res) == 2 assert res[0].name == "Test Bank 234" - async def test_add_domains( - self, transaction_session: AsyncSession, query_session: AsyncSession - ): + async def test_add_domains(self, transaction_session: AsyncSession, query_session: AsyncSession): await repo.add_domains( transaction_session, "TESTBANK123", @@ -91,11 +79,5 @@ async def test_domain_allowed(self, transaction_session: AsyncSession): denied_domain = DeniedDomainDao(domain="yahoo.com") transaction_session.add(denied_domain) await transaction_session.commit() - assert ( - await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") - is False - ) - assert ( - await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") - is True - ) + assert await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") is False + assert await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") is True