diff --git a/pyproject.toml b/pyproject.toml index 09496e7f..7c02656e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,13 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] asyncio_mode = "auto" pythonpath = ["src"] +env = [ + "INST_DB_SCHEMA=main", + "INST_DB_USER=user", + "INST_DB_PWD=user", + "INST_DB_HOST=localhost", + "INST_DB_NAME=filing" +] addopts = [ "--cov-report=term-missing", "--cov-branch", diff --git a/src/config.py b/src/config.py new file mode 100644 index 00000000..627df8c3 --- /dev/null +++ b/src/config.py @@ -0,0 +1,41 @@ +import os +from urllib import parse +from typing import Any + +from pydantic import field_validator, ValidationInfo +from pydantic.networks import PostgresDsn +from pydantic_settings import BaseSettings, SettingsConfigDict + +env_files_to_load = [".env"] +if os.getenv("ENV", "LOCAL") == "LOCAL": + env_files_to_load.append(".env.local") + + +class Settings(BaseSettings): + inst_db_schema: str = "public" + inst_db_name: str + inst_db_user: str + inst_db_pwd: str + inst_db_host: str + inst_db_scheme: str = "postgresql+asyncpg" + inst_conn: PostgresDsn | None = None + + def __init__(self, **data): + super().__init__(**data) + + @field_validator("inst_conn", mode="before") + @classmethod + def build_postgres_dsn(cls, postgres_dsn, info: ValidationInfo) -> Any: + postgres_dsn = PostgresDsn.build( + scheme=info.data.get("inst_db_scheme"), + username=info.data.get("inst_db_user"), + password=parse.quote(info.data.get("inst_db_pwd"), safe=""), + host=info.data.get("inst_db_host"), + path=info.data.get("inst_db_name"), + ) + return str(postgres_dsn) + + model_config = SettingsConfigDict(env_file=env_files_to_load, extra="allow") + + +settings = Settings() diff --git a/src/entities/engine/__init__.py b/src/entities/engine/__init__.py new file mode 100644 index 00000000..1fa81932 --- /dev/null +++ b/src/entities/engine/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["get_session"] + +from .engine import get_session 
diff --git a/src/entities/engine/engine.py b/src/entities/engine/engine.py new file mode 100644 index 00000000..9a436825 --- /dev/null +++ b/src/entities/engine/engine.py @@ -0,0 +1,20 @@ +from sqlalchemy.ext.asyncio import ( + create_async_engine, + async_sessionmaker, + async_scoped_session, +) +from asyncio import current_task +from config import settings + +engine = create_async_engine(settings.inst_conn.unicode_string(), echo=True).execution_options( + schema_translate_map={None: settings.inst_db_schema} +) +SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) + + +async def get_session(): + session = SessionLocal() + try: + yield session + finally: + await session.close() diff --git a/src/entities/models/__init__.py b/src/entities/models/__init__.py new file mode 100644 index 00000000..1dec549a --- /dev/null +++ b/src/entities/models/__init__.py @@ -0,0 +1,16 @@ +__all__ = [ + "Base", + "SubmissionDAO", + "SubmissionDTO", + "SubmissionState", + "FilingDAO", + "FilingDTO", + "FilingPeriodDAO", + "FilingPeriodDTO", + "FilingType", + "FilingState", +] + +from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO +from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO +from .model_enums import FilingType, FilingState, SubmissionState diff --git a/src/entities/models/dao.py b/src/entities/models/dao.py new file mode 100644 index 00000000..61866294 --- /dev/null +++ b/src/entities/models/dao.py @@ -0,0 +1,68 @@ +from .model_enums import FilingType, FilingState, SubmissionState +from datetime import datetime +from typing import Any +from sqlalchemy import Enum as SAEnum +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.types import JSON + + +class Base(AsyncAttrs, DeclarativeBase): + pass + + +class SubmissionDAO(Base): + __tablename__ = "submission" + id: Mapped[int] 
= mapped_column(index=True, primary_key=True, autoincrement=True)
+    submitter: Mapped[str]
+    state: Mapped[SubmissionState] = mapped_column(SAEnum(SubmissionState))
+    validation_ruleset_version: Mapped[str]
+    validation_json: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=True)
+    filing: Mapped[int] = mapped_column(ForeignKey("filing.id"))
+
+    def __str__(self):
+        return f"Submission ID: {self.id}, Submitter: {self.submitter}, State: {self.state}, Ruleset: {self.validation_ruleset_version}, Filing: {self.filing}"
+
+
+class FilingPeriodDAO(Base):
+    __tablename__ = "filing_period"
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+    name: Mapped[str]
+    start_period: Mapped[datetime]
+    end_period: Mapped[datetime]
+    due: Mapped[datetime]
+    filing_type: Mapped[FilingType] = mapped_column(SAEnum(FilingType))
+
+
+class FilingDAO(Base):
+    __tablename__ = "filing"
+    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+    lei: Mapped[str]
+    state: Mapped[FilingState] = mapped_column(SAEnum(FilingState))
+    filing_period: Mapped[int] = mapped_column(ForeignKey("filing_period.id"))
+    institution_snapshot_id: Mapped[str]  # not sure what this is
+
+
+# Commenting out for now since we're just storing the results from the data-validator as JSON.
+# If we determine building the data structure for results as tables is needed, we can add these
+# back in. 
+# class FindingDAO(Base): +# __tablename__ = "submission_finding" +# id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) +# submission_id: Mapped[str] = mapped_column(ForeignKey("submission.id")) +# submission: Mapped["SubmissionDAO"] = relationship(back_populates="results") # if we care about bidirectional +# validation_code: Mapped[str] +# severity: Mapped[Severity] = mapped_column(Enum(*get_args(Severity))) +# records: Mapped[List["RecordDAO"]] = relationship(back_populates="result") + + +# class RecordDAO(Base): +# __tablename__ = "submission_finding_record" +# id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) +# result_id: Mapped[str] = mapped_column(ForeignKey("submission_finding.id")) +# result: Mapped["FindingDAO"] = relationship(back_populates="records") # if we care about bidirectional +# record: Mapped[int] +# field_name: Mapped[str] +# field_value: Mapped[str] diff --git a/src/entities/models/dto.py b/src/entities/models/dto.py new file mode 100644 index 00000000..1ec961a6 --- /dev/null +++ b/src/entities/models/dto.py @@ -0,0 +1,36 @@ +from datetime import datetime +from typing import Dict, Any +from pydantic import BaseModel, ConfigDict +from .model_enums import FilingType, FilingState, SubmissionState + + +class SubmissionDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int | None = None + submitter: str + state: SubmissionState | None = None + validation_ruleset_version: str | None = None + validation_json: Dict[str, Any] | None = None + filing: int + + +class FilingDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int | None = None + lei: str + state: FilingState + filing_period: int + institution_snapshot_id: str + + +class FilingPeriodDTO(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int | None = None + name: str + start_period: datetime + end_period: datetime + due: datetime + filing_type: FilingType diff --git 
a/src/entities/models/model_enums.py b/src/entities/models/model_enums.py new file mode 100644 index 00000000..502b31f5 --- /dev/null +++ b/src/entities/models/model_enums.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class SubmissionState(str, Enum): + SUBMISSION_UPLOADED = "SUBMISSION_UPLOADED" + VALIDATION_IN_PROGRESS = "VALIDATION_IN_PROGRESS" + VALIDATION_WITH_ERRORS = "VALIDATION_WITH_ERRORS" + VALIDATION_WITH_WARNINGS = "VALIDATION_WITH_WARNINGS" + VALIDATION_SUCCESSFUL = "VALIDATION_SUCCESSFUL" + SUBMISSION_SIGNED = "SUBMISSION_SIGNED" + + +class FilingState(str, Enum): + FILING_STARTED = "FILING_STARTED" + FILING_IN_PROGRESS = "FILING_IN_PROGRESS" + FILING_COMPLETE = "FILING_COMPLETE" + + +class FilingType(str, Enum): + MANUAL = "MANUAL" diff --git a/src/entities/repos/__init__.py b/src/entities/repos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/entities/repos/submission_repo.py b/src/entities/repos/submission_repo.py new file mode 100644 index 00000000..19445ccc --- /dev/null +++ b/src/entities/repos/submission_repo.py @@ -0,0 +1,96 @@ +import logging + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from typing import Any, List, TypeVar +from entities.engine import get_session + + +from entities.models import ( + SubmissionDAO, + SubmissionDTO, + SubmissionState, + FilingPeriodDAO, + FilingPeriodDTO, + FilingDTO, + FilingDAO, +) + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + + +async def get_submissions(session: AsyncSession, filing_id: int = None) -> List[SubmissionDAO]: + async with session.begin(): + stmt = select(SubmissionDAO) + if filing_id: + stmt = stmt.filter(SubmissionDAO.filing == filing_id) + results = await session.scalars(stmt) + return results.all() + + +async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO: + return await query_helper(session, submission_id, SubmissionDAO) + + +async def get_filing(session: AsyncSession, 
filing_id: int) -> FilingDAO: + return await query_helper(session, filing_id, FilingDAO) + + +async def get_filing_period(session: AsyncSession, filing_period_id: int) -> FilingPeriodDAO: + return await query_helper(session, filing_period_id, FilingPeriodDAO) + + +async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> SubmissionDAO: + async with session.begin(): + new_sub = SubmissionDAO( + filing=submission.filing, + submitter=submission.submitter, + state=SubmissionState.SUBMISSION_UPLOADED, + validation_ruleset_version="v1", + ) + # this returns the attached object, most importantly with the new submission id + new_sub = await session.merge(new_sub) + await session.commit() + return new_sub + + +async def update_submission(submission: SubmissionDAO, incoming_session: AsyncSession = None) -> SubmissionDAO: + session = incoming_session if incoming_session else await anext(get_session()) + async with session.begin(): + try: + new_sub = await session.merge(submission) + await session.commit() + return new_sub + except Exception as e: + await session.rollback() + logger.error(f"There was an exception storing the updated SubmissionDAO, rolling back transaction: {e}") + raise + + +async def upsert_filing_period(session: AsyncSession, filing_period: FilingPeriodDTO) -> FilingPeriodDAO: + return await upsert_helper(session, filing_period, FilingPeriodDAO) + + +async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO: + return await upsert_helper(session, filing, FilingDAO) + + +async def upsert_helper(session: AsyncSession, original_data: Any, type: T) -> T: + async with session.begin(): + copy_data = original_data.__dict__.copy() + # this is only for if a DAO is passed in + # Should be DTOs, but hey, it's python + if copy_data["id"] is not None and "_sa_instance_state" in copy_data: + del copy_data["_sa_instance_state"] + new_dao = type(**copy_data) + new_dao = await session.merge(new_dao) + await session.commit() + return 
new_dao + + +async def query_helper(session: AsyncSession, id: int, type: T) -> T: + async with session.begin(): + stmt = select(type).filter(type.id == id) + return await session.scalar(stmt) diff --git a/tests/entities/conftest.py b/tests/entities/conftest.py new file mode 100644 index 00000000..fc0633d0 --- /dev/null +++ b/tests/entities/conftest.py @@ -0,0 +1,61 @@ +import asyncio +import pytest + +from asyncio import current_task +from sqlalchemy.ext.asyncio import ( + create_async_engine, + AsyncEngine, + async_scoped_session, + async_sessionmaker, +) +from entities.models import Base + + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + +@pytest.fixture(scope="session") +def engine(): + return create_async_engine("sqlite+aiosqlite://") + + +@pytest.fixture(scope="function", autouse=True) +async def setup_db( + request: pytest.FixtureRequest, + engine: AsyncEngine, + event_loop: asyncio.AbstractEventLoop, +): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + def teardown(): + async def td(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + event_loop.run_until_complete(td()) + + request.addfinalizer(teardown) + + +@pytest.fixture(scope="function") +async def transaction_session(session_generator: async_scoped_session): + async with session_generator() as session: + yield session + + +@pytest.fixture(scope="function") +async def query_session(session_generator: async_scoped_session): + async with session_generator() as session: + yield session + + +@pytest.fixture(scope="function") +def session_generator(engine: AsyncEngine): + return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task) diff --git a/tests/entities/repos/test_submission_repo.py b/tests/entities/repos/test_submission_repo.py new file mode 100644 index 00000000..0c924ef8 --- /dev/null +++ 
b/tests/entities/repos/test_submission_repo.py @@ -0,0 +1,229 @@ +import pandas as pd +import pytest + +from datetime import datetime +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session + +from entities.models import ( + SubmissionDAO, + SubmissionDTO, + FilingPeriodDAO, + FilingPeriodDTO, + FilingDAO, + FilingDTO, + FilingType, + FilingState, + SubmissionState, +) +from entities.repos import submission_repo as repo + +from pytest_mock import MockerFixture + +from entities.engine import engine as entities_engine + + +class TestSubmissionRepo: + @pytest.fixture(scope="function", autouse=True) + async def setup( + self, transaction_session: AsyncSession, mocker: MockerFixture, session_generator: async_scoped_session + ): + mocker.patch.object(entities_engine, "SessionLocal", return_value=session_generator) + + filing_period = FilingPeriodDAO( + name="FilingPeriod2024", + start_period=datetime.now(), + end_period=datetime.now(), + due=datetime.now(), + filing_type=FilingType.MANUAL, + ) + transaction_session.add(filing_period) + + filing1 = FilingDAO( + lei="1234567890", state=FilingState.FILING_STARTED, institution_snapshot_id="Snapshot-1", filing_period=1 + ) + filing2 = FilingDAO( + lei="ABCDEFGHIJ", state=FilingState.FILING_STARTED, institution_snapshot_id="Snapshot-1", filing_period=1 + ) + transaction_session.add(filing1) + transaction_session.add(filing2) + + submission1 = SubmissionDAO( + submitter="test1@cfpb.gov", + filing=1, + state=SubmissionState.SUBMISSION_UPLOADED, + validation_ruleset_version="v1", + ) + submission2 = SubmissionDAO( + submitter="test2@cfpb.gov", + filing=2, + state=SubmissionState.SUBMISSION_UPLOADED, + validation_ruleset_version="v1", + ) + submission3 = SubmissionDAO( + submitter="test2@cfpb.gov", + filing=2, + state=SubmissionState.SUBMISSION_UPLOADED, + validation_ruleset_version="v1", + ) + transaction_session.add(submission1) + transaction_session.add(submission2) + 
transaction_session.add(submission3)
+
+        await transaction_session.commit()
+
+    async def test_add_filing_period(self, transaction_session: AsyncSession):
+        new_fp = FilingPeriodDTO(
+            name="FilingPeriod2024.1",
+            start_period=datetime.now(),
+            end_period=datetime.now(),
+            due=datetime.now(),
+            filing_type=FilingType.MANUAL,
+        )
+        res = await repo.upsert_filing_period(transaction_session, new_fp)
+        assert res.id == 2
+        assert res.filing_type == FilingType.MANUAL
+
+    async def test_get_filing_period(self, query_session: AsyncSession):
+        res = await repo.get_filing_period(query_session, filing_period_id=1)
+        assert res.id == 1
+        assert res.name == "FilingPeriod2024"
+        assert res.filing_type == FilingType.MANUAL
+
+    async def test_add_and_modify_filing(self, transaction_session: AsyncSession):
+        new_filing = FilingDTO(
+            lei="12345ABCDE",
+            state=FilingState.FILING_IN_PROGRESS,
+            institution_snapshot_id="Snapshot-1",
+            filing_period=1,
+        )
+        res = await repo.upsert_filing(transaction_session, new_filing)
+        assert res.id == 3
+        assert res.lei == "12345ABCDE"
+        assert res.state == FilingState.FILING_IN_PROGRESS
+
+        mod_filing = FilingDTO(
+            id=3,
+            lei="12345ABCDE",
+            state=FilingState.FILING_COMPLETE,
+            institution_snapshot_id="Snapshot-1",
+            filing_period=1,
+        )
+        res = await repo.upsert_filing(transaction_session, mod_filing)
+        assert res.id == 3
+        assert res.lei == "12345ABCDE"
+        assert res.state == FilingState.FILING_COMPLETE
+
+    async def test_get_filing(self, query_session: AsyncSession):
+        res = await repo.get_filing(query_session, filing_id=1)
+        assert res.id == 1
+        assert res.lei == "1234567890"
+        assert res.state == FilingState.FILING_STARTED
+
+    async def test_get_submission(self, query_session: AsyncSession):
+        res = await repo.get_submission(query_session, submission_id=1)
+        assert res.id == 1
+        assert res.submitter == "test1@cfpb.gov"
+        assert res.filing == 1
+        assert res.state == SubmissionState.SUBMISSION_UPLOADED
+        assert 
res.validation_ruleset_version == "v1" + + async def test_get_submissions(self, query_session: AsyncSession): + res = await repo.get_submissions(query_session) + assert len(res) == 3 + assert {1, 2, 3} == set([s.id for s in res]) + assert res[0].submitter == "test1@cfpb.gov" + assert res[1].filing == 2 + assert res[2].state == SubmissionState.SUBMISSION_UPLOADED + + res = await repo.get_submissions(query_session, filing_id=2) + assert len(res) == 2 + assert {2, 3} == set([s.id for s in res]) + assert {"test2@cfpb.gov"} == set([s.submitter for s in res]) + assert {2} == set([s.filing for s in res]) + assert {SubmissionState.SUBMISSION_UPLOADED} == set([s.state for s in res]) + + async def test_add_submission(self, transaction_session: AsyncSession): + res = await repo.add_submission(transaction_session, SubmissionDTO(submitter="test@cfpb.gov", filing=1)) + assert res.id == 4 + assert res.submitter == "test@cfpb.gov" + assert res.filing == 1 + assert res.state == SubmissionState.SUBMISSION_UPLOADED + assert res.validation_ruleset_version == "v1" + + async def test_update_submission(self, session_generator: async_scoped_session): + async with session_generator() as add_session: + res = await repo.add_submission(add_session, SubmissionDTO(submitter="test2@cfpb.gov", filing=2)) + + res.state = SubmissionState.VALIDATION_IN_PROGRESS + res = await repo.update_submission(res) + + async def query_updated_dao(): + async with session_generator() as search_session: + stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) + new_res1 = await search_session.scalar(stmt) + assert new_res1.id == 4 + assert new_res1.filing == 2 + assert new_res1.state == SubmissionState.VALIDATION_IN_PROGRESS + + await query_updated_dao() + + validation_json = self.get_error_json() + res.validation_json = validation_json + res.state = SubmissionState.VALIDATION_WITH_ERRORS + # to test passing in a session to the update_submission function + async with session_generator() as update_session: + res 
= await repo.update_submission(res, update_session) + + async def query_updated_dao(): + async with session_generator() as search_session: + stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) + new_res2 = await search_session.scalar(stmt) + assert new_res2.id == 4 + assert new_res2.filing == 2 + assert new_res2.state == SubmissionState.VALIDATION_WITH_ERRORS + assert new_res2.validation_json == validation_json + + await query_updated_dao() + + def get_error_json(self): + df_columns = [ + "record_no", + "field_name", + "field_value", + "validation_severity", + "validation_id", + "validation_name", + "validation_desc", + ] + df_data = [ + [ + 0, + "uid", + "BADUID0", + "error", + "E0001", + "id.invalid_text_length", + "'Unique identifier' must be at least 21 characters in length.", + ], + [ + 0, + "uid", + "BADTEXTLENGTH", + "error", + "E0100", + "ct_credit_product_ff.invalid_text_length", + "'Free-form text field for other credit products' must not exceed 300 characters in length.", + ], + [ + 1, + "uid", + "BADUID1", + "error", + "E0001", + "id.invalid_text_length", + "'Unique identifier' must be at least 21 characters in length.", + ], + ] + error_df = pd.DataFrame(df_data, columns=df_columns) + return error_df.to_json()