Skip to content

Commit

Permalink
poc: request action validation concept
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Nov 26, 2024
1 parent 2c65294 commit f185197
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 67 deletions.
5 changes: 4 additions & 1 deletion src/sbl_filing_api/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sbl_filing_api.entities.models.model_enums import FilingType, FilingTaskState, SubmissionState, UserActionType
from datetime import datetime
from typing import Any, List
from sqlalchemy import Enum as SAEnum, String
from sqlalchemy import Enum as SAEnum, String, desc
from sqlalchemy import ForeignKey, func, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship
from sqlalchemy.ext.asyncio import AsyncAttrs
Expand Down Expand Up @@ -115,6 +115,9 @@ class FilingDAO(Base):
tasks: Mapped[List[FilingTaskProgressDAO] | None] = relationship(lazy="selectin", cascade="all, delete-orphan")
institution_snapshot_id: Mapped[str] = mapped_column(nullable=True)
contact_info: Mapped[ContactInfoDAO] = relationship("ContactInfoDAO", lazy="joined")
submissions: Mapped[List[SubmissionDAO] | None] = relationship(
"SubmissionDAO", lazy="select", order_by=desc(SubmissionDAO.submission_time)
)
signatures: Mapped[List[UserActionDAO] | None] = relationship(
"UserActionDAO", secondary="filing_signature", lazy="selectin", order_by="desc(UserActionDAO.timestamp)"
)
Expand Down
32 changes: 1 addition & 31 deletions src/sbl_filing_api/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from regtech_api_commons.models.auth import AuthenticatedUser

from copy import deepcopy

from async_lru import alru_cache

from sbl_filing_api.entities.models.dao import (
Expand Down Expand Up @@ -64,23 +62,17 @@ async def get_submission_by_counter(session: AsyncSession, lei: str, filing_peri

async def get_filing(session: AsyncSession, lei: str, filing_period: str) -> FilingDAO:
result = await query_helper(session, FilingDAO, lei=lei, filing_period=filing_period)
if result:
result = await populate_missing_tasks(session, result)
return result[0] if result else None


async def get_filings(session: AsyncSession, leis: list[str], filing_period: str) -> list[FilingDAO]:
stmt = select(FilingDAO).filter(FilingDAO.lei.in_(leis), FilingDAO.filing_period == filing_period)
result = (await session.scalars(stmt)).all()
if result:
result = await populate_missing_tasks(session, result)
return result if result else []


async def get_period_filings(session: AsyncSession, filing_period: str) -> List[FilingDAO]:
filings = await query_helper(session, FilingDAO, filing_period=filing_period)
if filings:
filings = await populate_missing_tasks(session, filings)
return filings


Expand Down Expand Up @@ -148,9 +140,7 @@ async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO:

async def create_new_filing(session: AsyncSession, lei: str, filing_period: str, creator_id: int) -> FilingDAO:
new_filing = FilingDAO(filing_period=filing_period, lei=lei, creator_id=creator_id)
new_filing = await upsert_helper(session, new_filing, FilingDAO)
new_filing = await populate_missing_tasks(session, [new_filing])
return new_filing[0]
return await upsert_helper(session, new_filing, FilingDAO)


async def update_task_state(
Expand Down Expand Up @@ -202,23 +192,3 @@ async def query_helper(session: AsyncSession, table_obj: T, **filter_args) -> Li
if filter_args:
stmt = stmt.filter_by(**filter_args)
return (await session.scalars(stmt)).all()


async def populate_missing_tasks(session: AsyncSession, filings: List[FilingDAO]):
filing_tasks = await get_filing_tasks(session)
filings_copy = deepcopy(filings)
for f in filings_copy:
tasks = [t.task.name for t in f.tasks]
missing_tasks = [t for t in filing_tasks if t.name not in tasks]
for mt in missing_tasks:
f.tasks.append(
FilingTaskProgressDAO(
filing=f.id,
task_name=mt.name,
task=mt,
state=FilingTaskState.NOT_STARTED,
user="",
)
)

return filings_copy
48 changes: 13 additions & 35 deletions src/sbl_filing_api/routers/filing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from regtech_api_commons.api.exceptions import RegTechHttpException
from regtech_api_commons.models.auth import AuthenticatedUser

from sbl_filing_api.entities.models.dao import FilingDAO
from sbl_filing_api.entities.models.model_enums import UserActionType
from sbl_filing_api.services import submission_processor
from sbl_filing_api.services.multithread_handler import handle_submission
Expand All @@ -37,6 +38,8 @@

from regtech_api_commons.api.dependencies import verify_user_lei_relation

from sbl_filing_api.services.request_action_validator import UserActionContext, validate_user_action, set_context

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -119,43 +122,18 @@ async def post_filing(request: Request, lei: str, period_code: str):
)


@router.put("/institutions/{lei}/filings/{period_code}/sign", response_model=FilingDTO)
@router.put(
"/institutions/{lei}/filings/{period_code}/sign",
response_model=FilingDTO,
dependencies=[
Depends(set_context({UserActionContext.INSTITUTION, UserActionContext.FILING})),
Depends(validate_user_action({UserActionType.SIGN})),
],
)
@requires("authenticated")
async def sign_filing(request: Request, lei: str, period_code: str):
filing = await repo.get_filing(request.state.db_session, lei, period_code)
if not filing:
raise RegTechHttpException(
status_code=status.HTTP_404_NOT_FOUND,
name="Filing Not Found",
detail=f"There is no Filing for LEI {lei} in period {period_code}, unable to sign a non-existent Filing.",
)
latest_sub = await repo.get_latest_submission(request.state.db_session, lei, period_code)
if not latest_sub or latest_sub.state != SubmissionState.SUBMISSION_ACCEPTED:
raise RegTechHttpException(
status_code=status.HTTP_403_FORBIDDEN,
name="Filing Action Forbidden",
detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have a latest submission the SUBMISSION_ACCEPTED state.",
)
if filing.is_voluntary is None:
raise RegTechHttpException(
status_code=status.HTTP_403_FORBIDDEN,
name="Filing Action Forbidden",
detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have a selection of is_voluntary defined.",
)
if not filing.contact_info:
raise RegTechHttpException(
status_code=status.HTTP_403_FORBIDDEN,
name="Filing Action Forbidden",
detail=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have contact info defined.",
)
"""
if not filing.institution_snapshot_id:
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content=f"Cannot sign filing. Filing for {lei} for period {period_code} does not have institution snapshot id defined.",
)
"""

filing: FilingDAO = request.state.context["filing"]
latest_sub = (await filing.awaitable_attrs.submissions)[0]
sig = await repo.add_user_action(
request.state.db_session,
UserActionDTO(
Expand Down
124 changes: 124 additions & 0 deletions src/sbl_filing_api/services/request_action_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import inspect
import json
import logging
from enum import StrEnum
from typing import Any, Dict, List, Set

from async_lru import alru_cache
import httpx
from fastapi import Request, status
from pydantic_settings import BaseSettings
from regtech_api_commons.api.exceptions import RegTechHttpException

from sbl_filing_api.config import settings
from sbl_filing_api.entities.models.dao import FilingDAO, SubmissionDAO
from sbl_filing_api.entities.models.model_enums import SubmissionState, UserActionType
from sbl_filing_api.entities.repos import submission_repo as repo

log = logging.getLogger(__name__)


class UserActionContext(StrEnum):
FILING = "filing"
INSTITUTION = "institution"


# class RequestActionValidationSettings(BaseSettings):
# check_lei_status: bool = True
# check_lei_tin: bool = True
# check_filing_exists: bool = True
# check_sub_accepted: bool = True
# check_voluntary_filer: bool = True
# check_contact_info: bool = True


@alru_cache(ttl=60*60)
async def get_institution_data(request: Request, lei: str):
async with httpx.AsyncClient() as client:
res = await client.get(settings.user_fi_api_url + lei, headers={"authorization": request.headers["authorization"]})
return res.json()


def check_lei_status(institution: Dict[str, Any], **kwargs):
try:
is_active = institution["lei_status"]["can_file"]
if not is_active:
return f"LEI status of {institution['lei_status_code']} cannot file."
except Exception:
log.exception("Unable to determine lei status: %s", json.dumps(institution))
return "Unable to determine LEI status."


def check_lei_tin(institution: Dict[str, Any], **kwargs):
if not institution["tax_id"]:
return "TIN is required to file"


def check_filing_exists(filing: FilingDAO, lei: str, period: str, **kwargs):
if not filing:
return f"There is no Filing for LEI {lei} in period {period}, unable to sign a non-existent Filing."


async def check_sub_accepted(filing: FilingDAO, **kwargs):
submissions: List[SubmissionDAO] = await filing.awaitable_attrs.submissions
if not len(submissions) or submissions[0].state != SubmissionState.SUBMISSION_ACCEPTED:
filing.lei
filing.filing_period
return f"Cannot sign filing. Filing for {filing.lei} for period {filing.filing_period} does not have a latest submission the SUBMISSION_ACCEPTED state."


def check_voluntary_filer(filing: FilingDAO, **kwargs):
if filing.is_voluntary is None:
return f"Cannot sign filing. Filing for {filing.lei} for period {filing.period} does not have a selection of is_voluntary defined."


def check_contact_info(filing: FilingDAO, **kwargs):
if not filing.contact_info:
return f"Cannot sign filing. Filing for {filing.lei} for period {filing.period} does not have contact info defined."


user_action_validation_registry = {
UserActionType.SIGN: {
check_lei_status,
check_lei_tin,
check_filing_exists,
check_sub_accepted,
check_voluntary_filer,
check_contact_info,
}
}


def set_context(requirements: Set[UserActionContext]):
async def _set_context(request: Request):
lei = request.path_params.get("lei")
period = request.path_params.get("period_code")
context = {"lei": lei, "period": period}
if lei and UserActionContext.INSTITUTION in requirements:
context = context | {UserActionContext.INSTITUTION: await get_institution_data(request, lei)}
if period and UserActionContext.FILING in requirements:
context = context | {UserActionContext.FILING: await repo.get_filing(request.state.db_session, lei, period)}
request.state.context = context

return _set_context


def validate_user_action(types: Set[str]):
async def _run_validations(request: Request):
res = []
for type in types:
checkers = user_action_validation_registry[type]
for checker in checkers:
if inspect.iscoroutinefunction(checker):
res.append(await checker(**request.state.context))
else:
res.append(checker(**request.state.context))
res = [r for r in res if r]
if len(res):
raise RegTechHttpException(
status_code=status.HTTP_403_FORBIDDEN,
name="Filing Action Forbidden",
detail=res,
)

return _run_validations

0 comments on commit f185197

Please sign in to comment.