diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index 3387d0b..2f24298 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -3,9 +3,13 @@ name: Linters on: [push] jobs: - linting: + black: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 - uses: chartboost/ruff-action@v1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fc88790..d2412b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ version = "0.1.0" description = "" authors = ["lchen-2101 <73617864+lchen-2101@users.noreply.github.com>"] readme = "README.md" -packages = [{include = "regtech-user-fi-management"}] +packages = [{ include = "regtech-user-fi-management" }] [tool.poetry.dependencies] python = "^3.11" @@ -30,22 +30,24 @@ pytest-mock = "^3.11.1" [tool.pytest.ini_options] asyncio_mode = "auto" -pythonpath = [ - "src" -] +pythonpath = ["src"] addopts = [ - "--cov-report=term-missing", - "--cov-branch", - "--cov-report=xml", - "--cov-report=term", - "--cov=src", - "-vv", - "--strict-markers", - "-rfE", -] -testpaths = [ - "tests", + "--cov-report=term-missing", + "--cov-branch", + "--cov-report=xml", + "--cov-report=term", + "--cov=src", + "-vv", + "--strict-markers", + "-rfE", ] +testpaths = ["tests"] + +[tool.black] +line-length = 120 + +[tool.ruff] +line-length = 120 [tool.coverage.run] relative_files = true diff --git a/src/dependencies.py b/src/dependencies.py index 384c7c9..02ddcb0 100644 --- a/src/dependencies.py +++ b/src/dependencies.py @@ -2,6 +2,10 @@ from typing import Annotated from fastapi import Depends, HTTPException, Request from sqlalchemy.ext.asyncio import AsyncSession +from typing import List, Optional +from itertools import chain + +from fastapi import Query from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -13,25 +17,33 @@ } -async def check_domain( - request: Request, session: Annotated[AsyncSession, Depends(get_session)] -) -> None: +async def check_domain(request: Request, session: Annotated[AsyncSession, Depends(get_session)]) -> None: if request_needs_domain_check(request): if not request.user.is_authenticated: raise HTTPException(status_code=HTTPStatus.FORBIDDEN) if await email_domain_denied(session, request.user.email): - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, detail="email domain denied" - ) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="email domain denied") def request_needs_domain_check(request: Request) -> bool: path = request.scope["path"].rstrip("/") - return not ( - path in OPEN_DOMAIN_REQUESTS - and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path] - ) + return not (path in OPEN_DOMAIN_REQUESTS and request.scope["method"] in OPEN_DOMAIN_REQUESTS[path]) async def email_domain_denied(session: AsyncSession, email: str) -> bool: return not await repo.is_email_domain_allowed(session, email) + + +def parse_leis(leis: List[str] = Query(None)) -> Optional[List]: + """ + Parses leis from list of one or multiple strings to a list of + multiple distinct lei strings. + Returns empty list when nothing is passed in + Ex1: ['lei1,lei2'] -> ['lei1', 'lei2'] + Ex2: ['lei1,lei2', 'lei3,lei4'] -> ['lei1','lei2','lei3','lei4'] + """ + + if leis: + return list(chain.from_iterable([x.split(",") for x in leis])) + else: + return None diff --git a/src/entities/engine/engine.py b/src/entities/engine/engine.py index f1b0082..d37c1b1 100644 --- a/src/entities/engine/engine.py +++ b/src/entities/engine/engine.py @@ -8,12 +8,8 @@ DB_URL = os.getenv("INST_CONN") DB_SCHEMA = os.getenv("INST_DB_SCHEMA", "public") -engine = create_async_engine(DB_URL, echo=True).execution_options( - schema_translate_map={None: DB_SCHEMA} -) -SessionLocal = async_scoped_session( - async_sessionmaker(engine, expire_on_commit=False), current_task -) +engine = create_async_engine(DB_URL, echo=True).execution_options(schema_translate_map={None: DB_SCHEMA}) +SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) async def get_session(): diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py index 3059bd1..3213524 100644 --- a/src/entities/models/dao.py +++ b/src/entities/models/dao.py @@ -23,9 +23,7 @@ class FinancialInstitutionDao(AuditMixin, Base): class FinancialInstitutionDomainDao(AuditMixin, Base): __tablename__ = "financial_institution_domains" domain: Mapped[str] = mapped_column(index=True, primary_key=True) - lei: Mapped[str] = mapped_column( - ForeignKey("financial_institutions.lei"), index=True, primary_key=True - ) + lei: Mapped[str] = mapped_column(ForeignKey("financial_institutions.lei"), index=True, primary_key=True) fi = relationship("FinancialInstitutionDao", back_populates="domains") diff --git a/src/entities/repos/institutions_repo.py b/src/entities/repos/institutions_repo.py index 887809b..ba33cf1 100644 --- a/src/entities/repos/institutions_repo.py +++ b/src/entities/repos/institutions_repo.py @@ -14,7 +14,11 @@ async def get_institutions( - session: AsyncSession, domain: str = "", page: int = 0, count: int = 100 + session: AsyncSession, + leis: List[str] = None, + domain: str = "", + page: int = 0, + count: int = 100, ) -> List[FinancialInstitutionDao]: async with session.begin(): stmt = ( @@ -23,7 +27,9 @@ async def get_institutions( .limit(count) .offset(page * count) ) - if d := domain.strip(): + if leis: + stmt = stmt.filter(FinancialInstitutionDao.lei.in_(leis)) + elif d := domain.strip(): search = "%{}%".format(d) stmt = stmt.join( FinancialInstitutionDomainDao, @@ -43,13 +49,9 @@ async def get_institution(session: AsyncSession, lei: str) -> FinancialInstituti return await session.scalar(stmt) -async def upsert_institution( - session: AsyncSession, fi: FinancialInstitutionDto -) -> FinancialInstitutionDao: +async def upsert_institution(session: AsyncSession, fi: FinancialInstitutionDto) -> FinancialInstitutionDao: async with session.begin(): - stmt = select(FinancialInstitutionDao).filter( - FinancialInstitutionDao.lei == fi.lei - ) + stmt = select(FinancialInstitutionDao).filter(FinancialInstitutionDao.lei == fi.lei) res = await session.execute(stmt) db_fi = res.scalar_one_or_none() if db_fi is None: diff --git a/src/main.py b/src/main.py index 3fd265a..bc54b7a 100644 --- a/src/main.py +++ b/src/main.py @@ -19,19 +19,13 @@ @app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exception: HTTPException -) -> JSONResponse: +async def http_exception_handler(request: Request, exception: HTTPException) -> JSONResponse: log.error(exception, exc_info=True, stack_info=True) - return JSONResponse( - status_code=exception.status_code, content={"detail": exception.detail} - ) + return JSONResponse(status_code=exception.status_code, content={"detail": exception.detail}) @app.exception_handler(Exception) -async def general_exception_handler( - request: Request, exception: Exception -) -> JSONResponse: +async def general_exception_handler(request: Request, exception: Exception) -> JSONResponse: log.error(exception, exc_info=True, stack_info=True) return JSONResponse( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, @@ -39,13 +33,9 @@ async def general_exception_handler( ) -oauth2_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL") -) +oauth2_scheme = OAuth2AuthorizationCodeBearer(authorizationUrl=os.getenv("AUTH_URL"), tokenUrl=os.getenv("TOKEN_URL")) -app.add_middleware( - AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme) -) +app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(oauth2_scheme)) app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/src/oauth2/oauth2_admin.py b/src/oauth2/oauth2_admin.py index 0b473d1..94fb16e 100644 --- a/src/oauth2/oauth2_admin.py +++ b/src/oauth2/oauth2_admin.py @@ -50,9 +50,7 @@ def update_user(self, user_id: str, payload: Dict[str, Any]) -> None: self._admin.update_user(user_id, payload) except kce.KeycloakError as e: log.exception("Failed to update user: %s", user_id, extra=payload) - raise HTTPException( - status_code=e.response_code, detail="Failed to update user" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to update user") def upsert_group(self, lei: str, name: str) -> str: try: @@ -65,9 +63,7 @@ def upsert_group(self, lei: str, name: str) -> str: return group["id"] except kce.KeycloakError as e: log.exception("Failed to upsert group, lei: %s, name: %s", lei, name) - raise HTTPException( - status_code=e.response_code, detail="Failed to upsert group" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to upsert group") def get_group(self, lei: str) -> Dict[str, Any] | None: try: @@ -80,9 +76,7 @@ def associate_to_group(self, user_id: str, group_id: str) -> None: self._admin.group_user_add(user_id, group_id) except kce.KeycloakError as e: log.exception("Failed to associate user %s to group %s", user_id, group_id) - raise HTTPException( - status_code=e.response_code, detail="Failed to associate user to group" - ) + raise HTTPException(status_code=e.response_code, detail="Failed to associate user to group") def associate_to_lei(self, user_id: str, lei: str) -> None: group = self.get_group(lei) diff --git a/src/oauth2/oauth2_backend.py b/src/oauth2/oauth2_backend.py index 605f3f4..9bffc6d 100644 --- a/src/oauth2/oauth2_backend.py +++ b/src/oauth2/oauth2_backend.py @@ -21,17 +21,13 @@ class BearerTokenAuthBackend(AuthenticationBackend): def __init__(self, token_bearer: OAuth2AuthorizationCodeBearer) -> None: self.token_bearer = token_bearer - async def authenticate( - self, conn: HTTPConnection - ) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]: + async def authenticate(self, conn: HTTPConnection) -> Coroutine[Any, Any, Tuple[AuthCredentials, BaseUser] | None]: try: token = await self.token_bearer(conn) claims = oauth2_admin.get_claims(token) if claims is not None: auths = ( - self.extract_nested( - claims, "resource_access", "realm-management", "roles" - ) + self.extract_nested(claims, "resource_access", "realm-management", "roles") + self.extract_nested(claims, "resource_access", "account", "roles") + ["authenticated"] ) diff --git a/src/routers/institutions.py b/src/routers/institutions.py index e54ed32..1de7b94 100644 --- a/src/routers/institutions.py +++ b/src/routers/institutions.py @@ -2,6 +2,7 @@ from http import HTTPStatus from oauth2 import oauth2_admin from util import Router +from dependencies import parse_leis from typing import Annotated, List, Tuple from entities.engine import get_session from entities.repos import institutions_repo as repo @@ -15,9 +16,7 @@ from starlette.authentication import requires -async def set_db( - request: Request, session: Annotated[AsyncSession, Depends(get_session)] -): +async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_session)]): request.state.db_session = session @@ -28,11 +27,12 @@ async def set_db( @requires("authenticated") async def get_institutions( request: Request, + leis: List[str] = Depends(parse_leis), domain: str = "", page: int = 0, count: int = 100, ): - return await repo.get_institutions(request.state.db_session, domain, page, count) + return await repo.get_institutions(request.state.db_session, leis, domain, page, count) @router.post("/", response_model=Tuple[str, FinancialInstitutionDto]) diff --git a/src/util/router_wrapper.py b/src/util/router_wrapper.py index 8efd137..b3cc6f4 100644 --- a/src/util/router_wrapper.py +++ b/src/util/router_wrapper.py @@ -11,9 +11,7 @@ def api_route( if path.endswith("/"): path = path[:-1] - add_path = super().api_route( - path, include_in_schema=include_in_schema, **kwargs - ) + add_path = super().api_route(path, include_in_schema=include_in_schema, **kwargs) add_alt_path = super().api_route(f"{path}/", include_in_schema=False, **kwargs) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 557f8b6..201c5fa 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -35,9 +35,7 @@ def authed_user_mock(auth_mock: Mock) -> Mock: "sub": "testuser123", } auth_mock.return_value = ( - AuthCredentials( - ["manage-account", "query-groups", "manage-users", "authenticated"] - ), + AuthCredentials(["manage-account", "query-groups", "manage-users", "authenticated"]), AuthenticatedUser.from_claim(claims), ) return auth_mock diff --git a/tests/api/routers/test_admin_api.py b/tests/api/routers/test_admin_api.py index d1b5f21..211fa0d 100644 --- a/tests/api/routers/test_admin_api.py +++ b/tests/api/routers/test_admin_api.py @@ -14,18 +14,14 @@ def test_get_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): res = client.get("/v1/admin/me") assert res.status_code == 403 - def test_get_me_authed_with_no_institutions( - self, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_get_me_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): client = TestClient(app_fixture) res = client.get("/v1/admin/me") assert res.status_code == 200 assert res.json().get("name") == "test" assert res.json().get("institutions") == [] - def test_get_me_authed_with_institutions( - self, app_fixture: FastAPI, auth_mock: Mock - ): + def test_get_me_authed_with_institutions(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", "preferred_username": "test_user", @@ -44,9 +40,7 @@ def test_get_me_authed_with_institutions( def test_update_me_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) - res = client.put( - "/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"} - ) + res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}) assert res.status_code == 403 def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): @@ -61,14 +55,10 @@ def test_update_me_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - res = client.put( - "/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"} - ) + res = client.put("/v1/admin/me", json={"firstName": "testFirst", "lastName": "testLast"}) assert res.status_code == 403 - def test_update_me( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_update_me(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): update_user_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.update_user") update_user_mock.return_value = None client = TestClient(app_fixture) @@ -77,12 +67,8 @@ def test_update_me( update_user_mock.assert_called_once_with("testuser123", data) assert res.status_code == 202 - def test_associate_institutions( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - associate_lei_mock = mocker.patch( - "oauth2.oauth2_admin.OAuth2Admin.associate_to_lei" - ) + def test_associate_institutions(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + associate_lei_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.associate_to_lei") associate_lei_mock.return_value = None client = TestClient(app_fixture) data = ["testlei1", "testlei2"] diff --git a/tests/api/routers/test_institutions_api.py b/tests/api/routers/test_institutions_api.py index a666e44..7130437 100644 --- a/tests/api/routers/test_institutions_api.py +++ b/tests/api/routers/test_institutions_api.py @@ -9,26 +9,18 @@ class TestInstitutionsApi: - def test_get_institutions_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_get_institutions_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) res = client.get("/v1/institutions/") assert res.status_code == 403 - def test_get_institutions_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - get_institutions_mock = mocker.patch( - "entities.repos.institutions_repo.get_institutions" - ) + def test_get_institutions_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institutions_mock = mocker.patch("entities.repos.institutions_repo.get_institutions") get_institutions_mock.return_value = [ FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) ] client = TestClient(app_fixture) @@ -36,40 +28,26 @@ def test_get_institutions_authed( assert res.status_code == 200 assert res.json()[0].get("name") == "Test Bank 123" - def test_create_institution_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_create_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 403 - def test_create_institution_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - upsert_institution_mock = mocker.patch( - "entities.repos.institutions_repo.upsert_institution" - ) + def test_create_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + upsert_institution_mock = mocker.patch("entities.repos.institutions_repo.upsert_institution") upsert_institution_mock.return_value = FinancialInstitutionDao( name="testName", lei="testLei", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) upsert_group_mock = mocker.patch("oauth2.oauth2_admin.OAuth2Admin.upsert_group") upsert_group_mock.return_value = "leiGroup" client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 200 assert res.json()[1].get("name") == "testName" - def test_create_institution_authed_no_permission( - self, app_fixture: FastAPI, auth_mock: Mock - ): + def test_create_institution_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", "preferred_username": "test_user", @@ -81,31 +59,21 @@ def test_create_institution_authed_no_permission( AuthenticatedUser.from_claim(claims), ) client = TestClient(app_fixture) - res = client.post( - "/v1/institutions/", json={"name": "testName", "lei": "testLei"} - ) + res = client.post("/v1/institutions/", json={"name": "testName", "lei": "testLei"}) assert res.status_code == 403 - def test_get_institution_unauthed( - self, app_fixture: FastAPI, unauthed_user_mock: Mock - ): + def test_get_institution_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mock): client = TestClient(app_fixture) lei_path = "testLeiPath" res = client.get(f"/v1/institutions/{lei_path}") assert res.status_code == 403 - def test_get_institution_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): - get_institution_mock = mocker.patch( - "entities.repos.institutions_repo.get_institution" - ) + def test_get_institution_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): + get_institution_mock = mocker.patch("entities.repos.institutions_repo.get_institution") get_institution_mock.return_value = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")], ) client = TestClient(app_fixture) lei_path = "testLeiPath" @@ -117,30 +85,20 @@ def test_add_domains_unauthed(self, app_fixture: FastAPI, unauthed_user_mock: Mo client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 - def test_add_domains_authed( - self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock - ): + def test_add_domains_authed(self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock): add_domains_mock = mocker.patch("entities.repos.institutions_repo.add_domains") - add_domains_mock.return_value = [ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ] + add_domains_mock.return_value = [FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123")] client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 200 assert res.json()[0].get("domain") == "test.bank" - def test_add_domains_authed_no_permission( - self, app_fixture: FastAPI, auth_mock: Mock - ): + def test_add_domains_authed_no_permission(self, app_fixture: FastAPI, auth_mock: Mock): claims = { "name": "test", "preferred_username": "test_user", @@ -153,9 +111,7 @@ def test_add_domains_authed_no_permission( ) client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 def test_add_domains_authed_with_denied_email_domain( @@ -165,8 +121,6 @@ def test_add_domains_authed_with_denied_email_domain( domain_denied_mock.return_value = True client = TestClient(app_fixture) lei_path = "testLeiPath" - res = client.post( - f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}] - ) + res = client.post(f"/v1/institutions/{lei_path}/domains/", json=[{"domain": "testDomain"}]) assert res.status_code == 403 assert "domain denied" in res.json()["detail"] diff --git a/tests/entities/conftest.py b/tests/entities/conftest.py index 990d971..fc0633d 100644 --- a/tests/entities/conftest.py +++ b/tests/entities/conftest.py @@ -58,6 +58,4 @@ async def query_session(session_generator: async_scoped_session): @pytest.fixture(scope="function") def session_generator(engine: AsyncEngine): - return async_scoped_session( - async_sessionmaker(engine, expire_on_commit=False), current_task - ) + return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) diff --git a/tests/entities/repos/test_institutions_repo.py b/tests/entities/repos/test_institutions_repo.py index f042c7f..5acdfd6 100644 --- a/tests/entities/repos/test_institutions_repo.py +++ b/tests/entities/repos/test_institutions_repo.py @@ -16,37 +16,46 @@ async def setup( self, transaction_session: AsyncSession, ): - fi_dao = FinancialInstitutionDao( + fi_dao_123, fi_dao_456 = FinancialInstitutionDao( name="Test Bank 123", lei="TESTBANK123", - domains=[ - FinancialInstitutionDomainDao(domain="test.bank", lei="TESTBANK123") - ], + domains=[FinancialInstitutionDomainDao(domain="test.bank.1", lei="TESTBANK123")], + ), FinancialInstitutionDao( + name="Test Bank 456", + lei="TESTBANK456", + domains=[FinancialInstitutionDomainDao(domain="test.bank.2", lei="TESTBANK456")], ) - transaction_session.add(fi_dao) + transaction_session.add(fi_dao_123) + transaction_session.add(fi_dao_456) await transaction_session.commit() async def test_get_institutions(self, query_session: AsyncSession): res = await repo.get_institutions(query_session) - assert len(res) == 1 + assert len(res) == 2 async def test_get_institutions_by_domain(self, query_session: AsyncSession): - res = await repo.get_institutions(query_session, domain="test.bank") + res = await repo.get_institutions(query_session, domain="test.bank.1") assert len(res) == 1 - async def test_get_institutions_by_domain_not_existing( - self, query_session: AsyncSession - ): + async def test_get_institutions_by_domain_not_existing(self, query_session: AsyncSession): res = await repo.get_institutions(query_session, domain="testing.bank") assert len(res) == 0 + async def test_get_institutions_by_lei_list(self, query_session: AsyncSession): + res = await repo.get_institutions(query_session, leis=["TESTBANK123", "TESTBANK456"]) + assert len(res) == 2 + + async def test_get_institutions_by_lei_list_item_not_existing(self, query_session: AsyncSession): + res = await repo.get_institutions(query_session, leis=["NOTTESTBANK"]) + assert len(res) == 0 + async def test_add_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( transaction_session, FinancialInstitutionDao(name="New Bank 123", lei="NEWBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 2 + assert len(res) == 3 async def test_update_institution(self, transaction_session: AsyncSession): await repo.upsert_institution( @@ -54,12 +63,10 @@ async def test_update_institution(self, transaction_session: AsyncSession): FinancialInstitutionDao(name="Test Bank 234", lei="TESTBANK123"), ) res = await repo.get_institutions(transaction_session) - assert len(res) == 1 + assert len(res) == 2 assert res[0].name == "Test Bank 234" - async def test_add_domains( - self, transaction_session: AsyncSession, query_session: AsyncSession - ): + async def test_add_domains(self, transaction_session: AsyncSession, query_session: AsyncSession): await repo.add_domains( transaction_session, "TESTBANK123", @@ -72,11 +79,5 @@ async def test_domain_allowed(self, transaction_session: AsyncSession): denied_domain = DeniedDomainDao(domain="yahoo.com") transaction_session.add(denied_domain) await transaction_session.commit() - assert ( - await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") - is False - ) - assert ( - await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") - is True - ) + assert await repo.is_email_domain_allowed(transaction_session, "test@yahoo.com") is False + assert await repo.is_email_domain_allowed(transaction_session, "test@gmail.com") is True