From 56bbd3d3487bab328d9077b9c1fca92317e13c29 Mon Sep 17 00:00:00 2001 From: lchen-2101 <73617864+lchen-2101@users.noreply.github.com> Date: Fri, 6 Dec 2024 10:06:51 -0800 Subject: [PATCH] feat: request action validation framework created, and implemented on sign and submit endpoint (#512) closes #482 closes #515 new context, and validation dependency concept, currently only applying to the `sign` endpoint, can extend to additional endpoints. also removed some of the tasks code since we no longer going with that concept; allowing for some daos to not be detached. --------- Co-authored-by: jcadam14 <41971533+jcadam14@users.noreply.github.com> --- pyproject.toml | 3 +- src/.env.local | 3 +- src/.env.template | 3 +- src/sbl_filing_api/config.py | 17 +- src/sbl_filing_api/entities/models/dao.py | 7 +- .../entities/repos/submission_repo.py | 39 +--- src/sbl_filing_api/routers/filing.py | 49 ++--- .../services/request_action_validator.py | 201 ++++++++++++++++++ tests/api/routers/test_filing_api.py | 153 ++++++------- tests/entities/repos/test_submission_repo.py | 92 -------- .../services/test_request_action_validator.py | 157 ++++++++++++++ 11 files changed, 483 insertions(+), 241 deletions(-) create mode 100644 src/sbl_filing_api/services/request_action_validator.py create mode 100644 tests/services/test_request_action_validator.py diff --git a/pyproject.toml b/pyproject.toml index d344cac3..3fae18c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,8 @@ env = [ "FS_UPLOAD_CONFIG__MKDIR=true", "FS_DOWNLOAD_CONFIG__PROTOCOL=file", "ENV=TEST", - "MAIL_API_URL=http://mail-api:8765/internal/confirmation/send" + "MAIL_API_URL=http://mail-api:8765/internal/confirmation/send", + 'REQUEST_VALIDATORS__SIGN_AND_SUBMIT=["check_lei_status","check_lei_tin","check_filing_exists","check_sub_accepted","check_voluntary_filer","check_contact_info"]' ] testpaths = ["tests"] diff --git a/src/.env.local b/src/.env.local index cf13f323..0d7c13fe 100644 --- a/src/.env.local +++ b/src/.env.local @@ -20,4 +20,5 @@ FS_UPLOAD_CONFIG__PROTOCOL="file" FS_UPLOAD_CONFIG__ROOT="../upload" EXPIRED_SUBMISSION_CHECK_SECS=120 SERVER_CONFIG__RELOAD="true" -MAIL_API_URL=http://mail-api:8765/internal/confirmation/send \ No newline at end of file +MAIL_API_URL=http://mail-api:8765/internal/confirmation/send +REQUEST_VALIDATORS__SIGN_AND_SUBMIT=["check_lei_status","check_lei_tin","check_filing_exists","check_sub_accepted","check_voluntary_filer","check_contact_info"] \ No newline at end of file diff --git a/src/.env.template b/src/.env.template index d57f0cf8..0fb05ab1 100644 --- a/src/.env.template +++ b/src/.env.template @@ -17,4 +17,5 @@ CERTS_URL=${KC_REALM_URL}/protocol/openid-connect/certs FS_UPLOAD_CONFIG__PROTOCOL="file" FS_UPLOAD_CONFIG__ROOT="../upload" USER_FI_API_URL=http://localhost:8881/v1/institutions/ -EXPIRED_SUBMISSION_CHECK_SECS=120 \ No newline at end of file +EXPIRED_SUBMISSION_CHECK_SECS=120 +REQUEST_VALIDATORS__SIGN_AND_SUBMIT=["check_lei_status","check_lei_tin","check_filing_exists","check_sub_accepted","check_voluntary_filer","check_contact_info"] \ No newline at end of file diff --git a/src/sbl_filing_api/config.py b/src/sbl_filing_api/config.py index 6490c6e8..b5e85e61 100644 --- a/src/sbl_filing_api/config.py +++ b/src/sbl_filing_api/config.py @@ -1,7 +1,7 @@ from enum import StrEnum import os from urllib import parse -from typing import Any +from typing import Any, Set from pydantic import field_validator, ValidationInfo, BaseModel from pydantic.networks import PostgresDsn @@ -83,8 +83,23 @@ def build_postgres_dsn(cls, postgres_dsn, info: ValidationInfo) -> Any: model_config = SettingsConfigDict(env_file=env_files_to_load, extra="allow", env_nested_delimiter="__") +class RequestActionValidations(BaseSettings): + sign_and_submit: Set[str] = { + "check_lei_status", + "check_lei_tin", + "check_filing_exists", + "check_sub_accepted", + "check_voluntary_filer", + "check_contact_info", + } + + model_config = SettingsConfigDict(env_prefix="request_validators__", env_file=env_files_to_load, extra="allow") + + settings = Settings() +request_action_validations = RequestActionValidations() + kc_settings = KeycloakSettings(_env_file=env_files_to_load) regex_configs = RegexConfigs.instance() diff --git a/src/sbl_filing_api/entities/models/dao.py b/src/sbl_filing_api/entities/models/dao.py index 0b00928c..e1593fa5 100644 --- a/src/sbl_filing_api/entities/models/dao.py +++ b/src/sbl_filing_api/entities/models/dao.py @@ -1,7 +1,7 @@ from sbl_filing_api.entities.models.model_enums import FilingType, FilingTaskState, SubmissionState, UserActionType from datetime import datetime from typing import Any, List -from sqlalchemy import Enum as SAEnum, String +from sqlalchemy import Enum as SAEnum, String, desc from sqlalchemy import ForeignKey, func, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship from sqlalchemy.ext.asyncio import AsyncAttrs @@ -114,7 +114,10 @@ class FilingDAO(Base): lei: Mapped[str] tasks: Mapped[List[FilingTaskProgressDAO] | None] = relationship(lazy="selectin", cascade="all, delete-orphan") institution_snapshot_id: Mapped[str] = mapped_column(nullable=True) - contact_info: Mapped[ContactInfoDAO] = relationship("ContactInfoDAO", lazy="joined") + contact_info: Mapped[ContactInfoDAO | None] = relationship("ContactInfoDAO", lazy="joined") + submissions: Mapped[List[SubmissionDAO] | None] = relationship( + "SubmissionDAO", lazy="select", order_by=desc(SubmissionDAO.submission_time) + ) signatures: Mapped[List[UserActionDAO] | None] = relationship( "UserActionDAO", secondary="filing_signature", lazy="selectin", order_by="desc(UserActionDAO.timestamp)" ) diff --git a/src/sbl_filing_api/entities/repos/submission_repo.py b/src/sbl_filing_api/entities/repos/submission_repo.py index 942db77b..70580ee2 100644 --- a/src/sbl_filing_api/entities/repos/submission_repo.py +++ b/src/sbl_filing_api/entities/repos/submission_repo.py @@ -7,8 +7,6 @@ from regtech_api_commons.models.auth import AuthenticatedUser -from copy import deepcopy - from async_lru import alru_cache from sbl_filing_api.entities.models.dao import ( @@ -64,23 +62,17 @@ async def get_submission_by_counter(session: AsyncSession, lei: str, filing_peri async def get_filing(session: AsyncSession, lei: str, filing_period: str) -> FilingDAO: result = await query_helper(session, FilingDAO, lei=lei, filing_period=filing_period) - if result: - result = await populate_missing_tasks(session, result) return result[0] if result else None async def get_filings(session: AsyncSession, leis: list[str], filing_period: str) -> list[FilingDAO]: stmt = select(FilingDAO).filter(FilingDAO.lei.in_(leis), FilingDAO.filing_period == filing_period) result = (await session.scalars(stmt)).all() - if result: - result = await populate_missing_tasks(session, result) return result if result else [] async def get_period_filings(session: AsyncSession, filing_period: str) -> List[FilingDAO]: filings = await query_helper(session, FilingDAO, filing_period=filing_period) - if filings: - filings = await populate_missing_tasks(session, filings) return filings @@ -148,9 +140,7 @@ async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO: async def create_new_filing(session: AsyncSession, lei: str, filing_period: str, creator_id: int) -> FilingDAO: new_filing = FilingDAO(filing_period=filing_period, lei=lei, creator_id=creator_id) - new_filing = await upsert_helper(session, new_filing, FilingDAO) - new_filing = await populate_missing_tasks(session, [new_filing]) - return new_filing[0] + return await upsert_helper(session, new_filing, FilingDAO) async def update_task_state( @@ -171,7 +161,12 @@ async def update_contact_info( session: AsyncSession, lei: str, filing_period: str, new_contact_info: ContactInfoDTO ) -> FilingDAO: filing = await get_filing(session, lei=lei, filing_period=filing_period) - filing.contact_info = ContactInfoDAO(**new_contact_info.__dict__.copy(), filing=filing.id) + if filing.contact_info: + for key, value in new_contact_info.__dict__.items(): + if key != "id": + setattr(filing.contact_info, key, value) + else: + filing.contact_info = ContactInfoDAO(**new_contact_info.__dict__.copy(), filing=filing.id) return await upsert_helper(session, filing, FilingDAO) @@ -202,23 +197,3 @@ async def query_helper(session: AsyncSession, table_obj: T, **filter_args) -> Li if filter_args: stmt = stmt.filter_by(**filter_args) return (await session.scalars(stmt)).all() - - -async def populate_missing_tasks(session: AsyncSession, filings: List[FilingDAO]): - filing_tasks = await get_filing_tasks(session) - filings_copy = deepcopy(filings) - for f in filings_copy: - tasks = [t.task.name for t in f.tasks] - missing_tasks = [t for t in filing_tasks if t.name not in tasks] - for mt in missing_tasks: - f.tasks.append( - FilingTaskProgressDAO( - filing=f.id, - task_name=mt.name, - task=mt, - state=FilingTaskState.NOT_STARTED, - user="", - ) - ) - - return filings_copy diff --git a/src/sbl_filing_api/routers/filing.py b/src/sbl_filing_api/routers/filing.py index 8f713a03..7a3adf14 100644 --- a/src/sbl_filing_api/routers/filing.py +++ b/src/sbl_filing_api/routers/filing.py @@ -11,9 +11,11 @@ from regtech_api_commons.api.exceptions import RegTechHttpException from regtech_api_commons.models.auth import AuthenticatedUser +from sbl_filing_api.entities.models.dao import FilingDAO from sbl_filing_api.entities.models.model_enums import UserActionType from sbl_filing_api.services import submission_processor from sbl_filing_api.services.multithread_handler import handle_submission +from sbl_filing_api.config import request_action_validations from typing import Annotated, List from sbl_filing_api.entities.engine.engine import get_session @@ -39,6 +41,8 @@ from sbl_filing_api.services.request_handler import send_confirmation_email +from sbl_filing_api.services.request_action_validator import UserActionContext, validate_user_action, set_context + logger = logging.getLogger(__name__) @@ -121,43 +125,18 @@ async def post_filing(request: Request, lei: str, period_code: str): ) -@router.put("/institutions/{lei}/filings/{period_code}/sign", response_model=FilingDTO) +@router.put( + "/institutions/{lei}/filings/{period_code}/sign", + response_model=FilingDTO, + dependencies=[ + Depends(set_context({UserActionContext.INSTITUTION, UserActionContext.FILING})), + Depends(validate_user_action(request_action_validations.sign_and_submit, "Filing Action Forbidden")), + ], +) @requires("authenticated") async def sign_filing(request: Request, lei: str, period_code: str): - filing = await repo.get_filing(request.state.db_session, lei, period_code) - if not filing: - raise RegTechHttpException( - status_code=status.HTTP_404_NOT_FOUND, - name="Filing Not Found", - detail=f"There is no Filing for LEI {lei} in period {period_code}, unable to sign a non-existent Filing.", - ) - latest_sub = await repo.get_latest_submission(request.state.db_session, lei, period_code) - if not latest_sub or latest_sub.state != SubmissionState.SUBMISSION_ACCEPTED: - raise RegTechHttpException( - status_code=status.HTTP_403_FORBIDDEN, - name="Filing Action Forbidden", - detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have a latest submission the SUBMISSION_ACCEPTED state.", - ) - if filing.is_voluntary is None: - raise RegTechHttpException( - status_code=status.HTTP_403_FORBIDDEN, - name="Filing Action Forbidden", - detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have a selection of is_voluntary defined.", - ) - if not filing.contact_info: - raise RegTechHttpException( - status_code=status.HTTP_403_FORBIDDEN, - name="Filing Action Forbidden", - detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have contact info defined.", - ) - """ - if not filing.institution_snapshot_id: - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have institution snapshot id defined.", - ) - """ - + filing: FilingDAO = request.state.context["filing"] + latest_sub = (await filing.awaitable_attrs.submissions)[0] sig = await repo.add_user_action( request.state.db_session, UserActionDTO( diff --git a/src/sbl_filing_api/services/request_action_validator.py b/src/sbl_filing_api/services/request_action_validator.py new file mode 100644 index 00000000..c56b9658 --- /dev/null +++ b/src/sbl_filing_api/services/request_action_validator.py @@ -0,0 +1,201 @@ +import inspect +import json +import logging +from abc import ABC, abstractmethod +from enum import StrEnum +from http import HTTPStatus +from typing import Any, Dict, List, Set + +import httpx +from async_lru import alru_cache +from fastapi import Request, status +from regtech_api_commons.api.exceptions import RegTechHttpException + +from sbl_filing_api.config import settings +from sbl_filing_api.entities.models.dao import FilingDAO, SubmissionDAO +from sbl_filing_api.entities.models.model_enums import SubmissionState +from sbl_filing_api.entities.repos import submission_repo as repo + +log = logging.getLogger(__name__) + + +class UserActionContext(StrEnum): + FILING = "filing" + INSTITUTION = "institution" + + +class FiRequest: + """ + FI retrieval request to allow cache to work + """ + + request: Request + lei: str + + def __init__(self, request: Request, lei: str): + self.request = request + self.lei = lei + + def __hash__(self): + return hash(self.lei) + + def __eq__(self, other: "FiRequest"): + return self.lei == other.lei + + +@alru_cache(ttl=60 * 60) +async def get_institution_data(fi_request: FiRequest): + try: + async with httpx.AsyncClient() as client: + res = await client.get( + settings.user_fi_api_url + fi_request.lei, + headers={"authorization": fi_request.request.headers["authorization"]}, + ) + if res.status_code == HTTPStatus.OK: + return res.json() + except Exception: + log.exception("Failed to retrieve fi data for %s", fi_request.lei) + + """ + `alru_cache` seems to cache `None` results, even though documentation for normal `lru_cache` seems to indicate it doesn't cache `None` by default. + So manually invalidate the cache if no returnable result found + """ + get_institution_data.cache_invalidate(fi_request) + + +class ActionValidator(ABC): + """ + Abstract Callable class for action validations, __subclasses__ method leveraged to construct a registry + """ + + name: str + + def __init__(self, name: str): + super().__init__() + self.name = name + + @abstractmethod + def __call__(self, *args, **kwargs): ... + + +class CheckLeiStatus(ActionValidator): + def __init__(self): + super().__init__("check_lei_status") + + def __call__(self, institution: Dict[str, Any], **kwargs): + try: + is_active = institution["lei_status"]["can_file"] + if not is_active: + return f"Cannot sign filing. LEI status of {institution['lei_status_code']} cannot file." + except Exception: + log.exception("Unable to determine lei status: %s", json.dumps(institution)) + return "Unable to determine LEI status." + + +class CheckLeiTin(ActionValidator): + def __init__(self): + super().__init__("check_lei_tin") + + def __call__(self, institution: Dict[str, Any], **kwargs): + if not (institution and institution.get("tax_id")): + return "Cannot sign filing. TIN is required to file." + + +class CheckFilingExists(ActionValidator): + def __init__(self): + super().__init__("check_filing_exists") + + def __call__(self, filing: FilingDAO, lei: str, period: str, **kwargs): + if not filing: + return f"There is no Filing for LEI {lei} in period {period}, unable to sign a non-existent Filing." + + +class CheckSubAccepted(ActionValidator): + def __init__(self): + super().__init__("check_sub_accepted") + + async def __call__(self, filing: FilingDAO, **kwargs): + if filing: + submissions: List[SubmissionDAO] = await filing.awaitable_attrs.submissions + if not len(submissions) or submissions[0].state != SubmissionState.SUBMISSION_ACCEPTED: + filing.lei + filing.filing_period + return f"Cannot sign filing. Filing for {filing.lei} for period {filing.filing_period} does not have a latest submission in the SUBMISSION_ACCEPTED state." + + +class CheckVoluntaryFiler(ActionValidator): + def __init__(self): + super().__init__("check_voluntary_filer") + + def __call__(self, filing: FilingDAO, **kwargs): + if filing and filing.is_voluntary is None: + return f"Cannot sign filing. Filing for {filing.lei} for period {filing.filing_period} does not have a selection of is_voluntary defined." + + +class CheckContactInfo(ActionValidator): + def __init__(self): + super().__init__("check_contact_info") + + def __call__(self, filing: FilingDAO, **kwargs): + if filing and not filing.contact_info: + return f"Cannot sign filing. Filing for {filing.lei} for period {filing.filing_period} does not have contact info defined." + + +validation_registry = { + validator.name: validator for validator in {Validator() for Validator in ActionValidator.__subclasses__()} +} + + +def set_context(requirements: Set[UserActionContext]): + """ + Sets a `context` object on `request.state`; this should typically include the institution, and filing; + `context` should be set before running any validation dependencies + Args: + requst (Request): request from the API endpoint + lei: comes from request path param + period: filing period comes from request path param + """ + + async def _set_context(request: Request): + lei = request.path_params.get("lei") + period = request.path_params.get("period_code") + context = {"lei": lei, "period": period} + if lei and UserActionContext.INSTITUTION in requirements: + context = context | {UserActionContext.INSTITUTION: await get_institution_data(FiRequest(request, lei))} + if period and UserActionContext.FILING in requirements: + context = context | {UserActionContext.FILING: await repo.get_filing(request.state.db_session, lei, period)} + request.state.context = context + + return _set_context + + +def validate_user_action(validator_names: Set[str], exception_name: str): + """ + Runs through list of validators, and aggregate into one exception to allow users know what all the errors are. + + Args: + validator_names (List[str]): list of names of the validators matching the ActionValidator.name attribute, + this is passed in from the endpoint dependency based on RequestActionValidations setting + configurable via `request_validators__` prefixed env vars + """ + + async def _run_validations(request: Request): + res = [] + for validator_name in validator_names: + validator = validation_registry.get(validator_name) + if not validator: + log.warning("Action validator [%s] not found.", validator_name) + elif inspect.iscoroutinefunction(validator.__call__): + res.append(await validator(**request.state.context)) + else: + res.append(validator(**request.state.context)) + + res = [r for r in res if r] + if len(res): + raise RegTechHttpException( + status_code=status.HTTP_403_FORBIDDEN, + name=exception_name, + detail=res, + ) + + return _run_validations diff --git a/tests/api/routers/test_filing_api.py b/tests/api/routers/test_filing_api.py index 73a2a90e..aaa89a97 100644 --- a/tests/api/routers/test_filing_api.py +++ b/tests/api/routers/test_filing_api.py @@ -997,25 +997,26 @@ async def test_accept_submission(self, mocker: MockerFixture, app_fixture: FastA async def test_good_sign_filing( self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock, get_filing_mock: Mock ): - mock = mocker.patch("sbl_filing_api.entities.repos.submission_repo.get_latest_submission") - mock.return_value = SubmissionDAO( - id=1, - counter=5, - submitter=UserActionDAO( + get_filing_mock.return_value.is_voluntary = True + get_filing_mock.return_value.submissions = [ + SubmissionDAO( id=1, - user_id="123456-7890-ABCDEF-GHIJ", - user_name="Test Submitter User", - user_email="test1@cfpb.gov", - action_type=UserActionType.SUBMIT, - timestamp=datetime.datetime.now(), - ), - filing=1, - state=SubmissionState.SUBMISSION_ACCEPTED, - validation_ruleset_version="v1", - submission_time=datetime.datetime.now(), - filename="file1.csv", - ) - get_filing_mock.return_value.is_voluntary = False + counter=5, + submitter=UserActionDAO( + id=1, + user_id="123456-7890-ABCDEF-GHIJ", + user_name="Test Submitter User", + user_email="test1@cfpb.gov", + action_type=UserActionType.SUBMIT, + timestamp=datetime.datetime.now(), + ), + filing=1, + state=SubmissionState.SUBMISSION_ACCEPTED, + validation_ruleset_version="v1", + submission_time=datetime.datetime.now(), + filename="file1.csv", + ) + ] add_sig_mock = mocker.patch("sbl_filing_api.entities.repos.submission_repo.add_user_action") add_sig_mock.return_value = UserActionDAO( @@ -1034,7 +1035,14 @@ async def test_good_sign_filing( updated_filing_obj = deepcopy(get_filing_mock.return_value) upsert_mock.return_value = updated_filing_obj - client = TestClient(app_fixture) + fi_data_mock = mocker.patch("sbl_filing_api.services.request_action_validator.get_institution_data") + fi_data_mock.return_value = { + "tax_id": "12-3456789", + "lei_status_code": "ISSUED", + "lei_status": {"code": "ISSUED", "name": "Issued", "can_file": True}, + } + + client = TestClient(app_fixture, headers={"authorization": "Bearer test123"}) res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") add_sig_mock.assert_called_with( ANY, @@ -1057,81 +1065,74 @@ async def test_good_sign_filing( async def test_errors_sign_filing( self, mocker: MockerFixture, app_fixture: FastAPI, authed_user_mock: Mock, get_filing_mock: Mock ): - sub_mock = mocker.patch("sbl_filing_api.entities.repos.submission_repo.get_latest_submission") send_email_mock = mocker.patch("sbl_filing_api.services.request_handler.send_confirmation_email") send_email_mock.return_value = None - sub_mock.return_value = SubmissionDAO( - id=1, - submitter=UserActionDAO( + + get_filing_mock.return_value.submissions = [ + SubmissionDAO( id=1, - user_id="1234-5678-ABCD-EFGH", - user_name="Test Submitter User", - user_email="test1@cfpb.gov", - action_type=UserActionType.SUBMIT, - timestamp=datetime.datetime.now(), - ), - filing=1, - state=SubmissionState.VALIDATION_SUCCESSFUL, - validation_ruleset_version="v1", - submission_time=datetime.datetime.now(), - filename="file1.csv", - ) + submitter=UserActionDAO( + id=1, + user_id="1234-5678-ABCD-EFGH", + user_name="Test Submitter User", + user_email="test1@cfpb.gov", + action_type=UserActionType.SUBMIT, + timestamp=datetime.datetime.now(), + ), + filing=1, + state=SubmissionState.VALIDATION_SUCCESSFUL, + validation_ruleset_version="v1", + submission_time=datetime.datetime.now(), + filename="file1.csv", + ) + ] + get_filing_mock.return_value.contact_info = None - client = TestClient(app_fixture) - res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") - assert res.status_code == 403 - assert ( - res.json()["error_detail"] - == "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a latest submission the SUBMISSION_ACCEPTED state." + add_sig_mock = mocker.patch("sbl_filing_api.entities.repos.submission_repo.add_user_action") + add_sig_mock.return_value = UserActionDAO( + id=2, + user_id="123456-7890-ABCDEF-GHIJ", + user_name="Test User", + user_email="test@local.host", + timestamp=datetime.datetime.now(), + action_type=UserActionType.SIGN, ) - sub_mock.return_value = None - res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") - assert res.status_code == 403 - assert ( - res.json()["error_detail"] - == "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a latest submission the SUBMISSION_ACCEPTED state." - ) + upsert_mock = mocker.patch("sbl_filing_api.entities.repos.submission_repo.upsert_filing") + updated_filing_obj = deepcopy(get_filing_mock.return_value) + upsert_mock.return_value = updated_filing_obj - sub_mock.return_value = SubmissionDAO( - id=1, - submitter=UserActionDAO( - id=1, - user_id="1234-5678-ABCD-EFGH", - user_name="Test Submitter User", - user_email="test1@cfpb.gov", - action_type=UserActionType.SUBMIT, - timestamp=datetime.datetime.now(), - ), - filing=1, - state=SubmissionState.SUBMISSION_ACCEPTED, - validation_ruleset_version="v1", - submission_time=datetime.datetime.now(), - filename="file1.csv", - ) + fi_data_mock = mocker.patch("sbl_filing_api.services.request_action_validator.get_institution_data") + fi_data_mock.return_value = { + "tax_id": None, + "lei_status_code": "LAPSED", + "lei_status": {"code": "LAPSED", "name": "Lapsed", "can_file": False}, + } + client = TestClient(app_fixture, headers={"authorization": "Bearer test123"}) res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") assert res.status_code == 403 + errors = res.json()["error_detail"] assert ( - res.json()["error_detail"] - == "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a selection of is_voluntary defined." + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a latest submission in the SUBMISSION_ACCEPTED state." + in errors ) - - get_filing_mock.return_value.is_voluntary = True - get_filing_mock.return_value.contact_info = None - res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") - assert res.status_code == 403 assert ( - res.json()["error_detail"] - == "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have contact info defined." + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a selection of is_voluntary defined." + in errors + ) + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have contact info defined." + in errors ) + assert "Cannot sign filing. TIN is required to file." in errors + assert "Cannot sign filing. LEI status of LAPSED cannot file." in errors get_filing_mock.return_value = None res = client.put("/v1/filing/institutions/1234567890ABCDEFGH00/filings/2024/sign") - assert res.status_code == 404 assert ( - res.json()["error_detail"] - == "There is no Filing for LEI 1234567890ABCDEFGH00 in period 2024, unable to sign a non-existent Filing." + "There is no Filing for LEI 1234567890ABCDEFGH00 in period 2024, unable to sign a non-existent Filing." + in res.json()["error_detail"] ) async def test_get_latest_sub_report( diff --git a/tests/entities/repos/test_submission_repo.py b/tests/entities/repos/test_submission_repo.py index 0b274fe2..ab9bb6b8 100644 --- a/tests/entities/repos/test_submission_repo.py +++ b/tests/entities/repos/test_submission_repo.py @@ -15,7 +15,6 @@ FilingTaskProgressDAO, FilingTaskDAO, FilingType, - FilingTaskState, SubmissionState, ContactInfoDAO, UserActionDAO, @@ -23,7 +22,6 @@ from sbl_filing_api.entities.models.dto import FilingPeriodDTO, ContactInfoDTO, UserActionDTO from sbl_filing_api.entities.models.model_enums import UserActionType from sbl_filing_api.entities.repos import submission_repo as repo -from regtech_api_commons.models.auth import AuthenticatedUser from pytest_mock import MockerFixture @@ -294,64 +292,11 @@ async def test_modify_filing(self, transaction_session: AsyncSession): assert res.creator.user_id == "123456-7890-ABCDEF-GHIJ" assert res.creator.user_name == "test creator" - async def test_get_filing_tasks(self, transaction_session: AsyncSession): - tasks = await repo.get_filing_tasks(transaction_session) - assert len(tasks) == 2 - assert tasks[0].name == "Task-1" - assert tasks[1].name == "Task-2" - - async def test_mod_filing_task(self, query_session: AsyncSession, transaction_session: AsyncSession): - user = AuthenticatedUser.from_claim({"preferred_username": "testuser"}) - await repo.update_task_state( - query_session, lei="1234567890", filing_period="2024", task_name="Task-1", state="COMPLETED", user=user - ) - seconds_now = dt.utcnow().timestamp() - filing = await repo.get_filing(query_session, lei="1234567890", filing_period="2024") - filing_task_states = filing.tasks - - assert len(filing_task_states) == 2 - assert filing_task_states[0].task.name == "Task-1" - assert filing_task_states[0].id == 1 - assert filing_task_states[0].filing == 1 - assert filing_task_states[0].state == FilingTaskState.COMPLETED - assert filing_task_states[0].user == "testuser" - assert filing_task_states[0].change_timestamp.timestamp() == pytest.approx( - seconds_now, abs=1.5 - ) # allow for possible 1.5 second difference - - async def test_add_filing_task(self, query_session: AsyncSession, transaction_session: AsyncSession): - user = AuthenticatedUser.from_claim({"preferred_username": "testuser"}) - await repo.update_task_state( - query_session, lei="1234567890", filing_period="2024", task_name="Task-2", state="IN_PROGRESS", user=user - ) - seconds_now = dt.utcnow().timestamp() - filing = await repo.get_filing(query_session, lei="1234567890", filing_period="2024") - filing_task_states = filing.tasks - - assert len(filing_task_states) == 2 - assert filing_task_states[1].task.name == "Task-2" - assert filing_task_states[1].id == 2 - assert filing_task_states[1].filing == 1 - assert filing_task_states[1].state == FilingTaskState.IN_PROGRESS - assert filing_task_states[1].user == "testuser" - assert filing_task_states[1].change_timestamp.timestamp() == pytest.approx( - seconds_now, abs=1.0 - ) # allow for possible 1 second difference - async def test_get_filing(self, query_session: AsyncSession, mocker: MockerFixture): - spy_populate_missing_tasks = mocker.patch( - "sbl_filing_api.entities.repos.submission_repo.populate_missing_tasks", wraps=repo.populate_missing_tasks - ) res1 = await repo.get_filing(query_session, lei="1234567890", filing_period="2024") assert res1.id == 1 assert res1.filing_period == "2024" assert res1.lei == "1234567890" - assert len(res1.tasks) == 2 - assert FilingTaskState.NOT_STARTED in set([t.state for t in res1.tasks]) - tasks1 = set([task_progress.task for task_progress in res1.tasks]) - assert len(tasks1) == 2 - assert "Task-1" in set([task.name for task in tasks1]) - assert "Task-2" in set([task.name for task in tasks1]) assert len(res1.signatures) == 2 assert res1.signatures[0].id == 5 assert res1.signatures[0].user_id == "test_sig@local.host" @@ -360,53 +305,16 @@ async def test_get_filing(self, query_session: AsyncSession, mocker: MockerFixtu assert res2.id == 2 assert res2.filing_period == "2024" assert res2.lei == "ABCDEFGHIJ" - assert len(res2.tasks) == 2 - assert FilingTaskState.NOT_STARTED in set([t.state for t in res2.tasks]) - tasks2 = set([task_progress.task for task_progress in res2.tasks]) - assert len(tasks2) == 2 - assert "Task-1" in set([task.name for task in tasks2]) - assert "Task-2" in set([task.name for task in tasks2]) - - tasks_populated_filings = [] - for call in spy_populate_missing_tasks.call_args_list: - args, _ = call - filings = args[1] - assert isinstance(filings[0], FilingDAO) - tasks_populated_filings.append(filings[0].id) - assert set(tasks_populated_filings) == set([1, 2]) async def test_get_filings(self, query_session: AsyncSession, mocker: MockerFixture): - spy_populate_missing_tasks = mocker.patch( - "sbl_filing_api.entities.repos.submission_repo.populate_missing_tasks", wraps=repo.populate_missing_tasks - ) res = await repo.get_filings(query_session, leis=["1234567890", "ABCDEFGHIJ"], filing_period="2024") assert res[0].id == 1 assert res[0].filing_period == "2024" assert res[0].lei == "1234567890" - assert len(res[0].tasks) == 2 - assert FilingTaskState.NOT_STARTED in set([t.state for t in res[0].tasks]) - tasks1 = set([task_progress.task for task_progress in res[0].tasks]) - assert len(tasks1) == 2 - assert "Task-1" in set([task.name for task in tasks1]) - assert "Task-2" in set([task.name for task in tasks1]) assert res[1].id == 2 assert res[1].filing_period == "2024" assert res[1].lei == "ABCDEFGHIJ" - assert len(res[1].tasks) == 2 - assert FilingTaskState.NOT_STARTED in set([t.state for t in res[1].tasks]) - tasks2 = set([task_progress.task for task_progress in res[1].tasks]) - assert len(tasks2) == 2 - assert "Task-1" in set([task.name for task in tasks2]) - assert "Task-2" in set([task.name for task in tasks2]) - - tasks_populated_filings = [] - for call in spy_populate_missing_tasks.call_args_list: - args, _ = call - filings = args[1] - assert all([isinstance(f, FilingDAO) for f in filings]) - tasks_populated_filings.extend([f.id for f in filings]) - assert set(tasks_populated_filings) == set([1, 2]) async def test_get_period_filings(self, query_session: AsyncSession, mocker: MockerFixture): results = await repo.get_period_filings(query_session, filing_period="2024") diff --git a/tests/services/test_request_action_validator.py b/tests/services/test_request_action_validator.py new file mode 100644 index 00000000..41cd8d0d --- /dev/null +++ b/tests/services/test_request_action_validator.py @@ -0,0 +1,157 @@ +from http import HTTPStatus +from logging import Logger + +import pytest +from fastapi import Request +from pytest_mock import MockerFixture +from regtech_api_commons.api.exceptions import RegTechHttpException + +from sbl_filing_api.entities.models.dao import ContactInfoDAO, FilingDAO, SubmissionDAO +from sbl_filing_api.entities.models.model_enums import SubmissionState +from sbl_filing_api.services.request_action_validator import UserActionContext, set_context, validate_user_action + + +@pytest.fixture +def httpx_unauthed_mock(mocker: MockerFixture) -> None: + mock_client_get = mocker.patch("httpx.AsyncClient.get") + mock_response = mocker.patch("httpx.Response") + mock_response.status_code = HTTPStatus.FORBIDDEN + mock_client_get.return_value = mock_response + + +@pytest.fixture +def httpx_authed_mock(mocker: MockerFixture) -> None: + mock_client_get = mocker.patch("httpx.AsyncClient.get") + mock_response = mocker.patch("httpx.Response") + mock_response.status_code = HTTPStatus.OK + mock_response.json.return_value = { + "tax_id": "12-3456789", + "lei_status_code": "LAPSED", + "lei_status": {"name": "Lapsed", "code": "LAPSED", "can_file": False}, + } + mock_client_get.return_value = mock_response + + +@pytest.fixture +async def filing_mock(mocker: MockerFixture) -> FilingDAO: + sub_mock = mocker.patch("sbl_filing_api.entities.models.dao.SubmissionDAO") + sub_mock.state = SubmissionState.UPLOAD_FAILED + filing = FilingDAO(lei="1234567890ABCDEFGH00", filing_period="2024", submissions=[sub_mock]) + return filing + + +@pytest.fixture +def request_mock(mocker: MockerFixture) -> Request: + mock = mocker.patch("fastapi.Request") + mock.path_params = {"lei": "1234567890ABCDEFGH00", "period_code": "2024"} + return mock + + +@pytest.fixture +def request_mock_valid_context(mocker: MockerFixture, request_mock: Request, filing_mock: FilingDAO) -> Request: + filing_mock.is_voluntary = True + filing_mock.submissions = [SubmissionDAO(state=SubmissionState.SUBMISSION_ACCEPTED)] + filing_mock.contact_info = ContactInfoDAO() + + request_mock.state.context = { + "lei": "1234567890ABCDEFGH00", + "period": "2024", + UserActionContext.INSTITUTION: { + "tax_id": "12-3456789", + "lei_status_code": "ISSUED", + "lei_status": {"name": "Issued", "code": "ISSUED", "can_file": True}, + }, + UserActionContext.FILING: filing_mock, + } + return request_mock + + +@pytest.fixture +def request_mock_invalid_context(mocker: MockerFixture, request_mock: Request, filing_mock: FilingDAO) -> Request: + request_mock.state.context = { + "lei": "1234567890ABCDEFGH00", + "period": "2024", + UserActionContext.INSTITUTION: { + "lei_status_code": "LAPSED", + "lei_status": {"name": "Lapsed", "code": "LAPSED", "can_file": False}, + }, + UserActionContext.FILING: filing_mock, + } + return request_mock + + +@pytest.fixture +def log_mock(mocker: MockerFixture) -> Logger: + return mocker.patch("sbl_filing_api.services.request_action_validator.log") + + +async def test_validations_with_errors(request_mock_invalid_context: Request): + run_validations = validate_user_action( + { + "check_lei_status", + "check_lei_tin", + "check_filing_exists", + "check_sub_accepted", + "check_voluntary_filer", + "check_contact_info", + }, + "Test Exception", + ) + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock_invalid_context) + assert e.value.name == "Test Exception" + errors = e.value.detail + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a latest submission in the SUBMISSION_ACCEPTED state." + in errors + ) + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have a selection of is_voluntary defined." + in errors + ) + assert ( + "Cannot sign filing. Filing for 1234567890ABCDEFGH00 for period 2024 does not have contact info defined." + in errors + ) + assert "Cannot sign filing. TIN is required to file." in errors + assert "Cannot sign filing. LEI status of LAPSED cannot file." in errors + + +async def test_validations_no_errors(request_mock_valid_context: Request): + run_validations = validate_user_action( + { + "check_lei_status", + "check_lei_tin", + "check_filing_exists", + "check_sub_accepted", + "check_voluntary_filer", + "check_contact_info", + }, + "Test Exception", + ) + await run_validations(request_mock_valid_context) + + +async def test_lei_status_bad_api_res(request_mock: Request, httpx_unauthed_mock): + run_validations = validate_user_action({"check_lei_status"}, "Test Exception") + context_setter = set_context({UserActionContext.INSTITUTION}) + await context_setter(request_mock) + + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock) + assert "Unable to determine LEI status." in e.value.detail + + +async def test_lei_status_good_api_res(request_mock: Request, httpx_authed_mock): + run_validations = validate_user_action({"check_lei_status"}, "Test Exception") + context_setter = set_context({UserActionContext.INSTITUTION}) + await context_setter(request_mock) + with pytest.raises(RegTechHttpException) as e: + await run_validations(request_mock) + assert "Cannot sign filing. LEI status of LAPSED cannot file." in e.value.detail + + +async def test_invalid_validation(request_mock_invalid_context: Request, log_mock: Logger): + run_validations = validate_user_action({"fake_validation"}, "Test Exception") + await run_validations(request_mock_invalid_context) + log_mock.warning.assert_called_with("Action validator [%s] not found.", "fake_validation")