diff --git a/src/entities/repos/submission_repo.py b/src/entities/repos/submission_repo.py index 9c834459..b3d351ea 100644 --- a/src/entities/repos/submission_repo.py +++ b/src/entities/repos/submission_repo.py @@ -1,6 +1,9 @@ +import logging + from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from typing import Any, List, TypeVar +from entities.engine import get_session from entities.models import ( SubmissionDAO, @@ -12,6 +15,8 @@ FilingDAO, ) +logger = logging.getLogger(__name__) + T = TypeVar("T") @@ -24,6 +29,13 @@ async def get_submissions(session: AsyncSession, filing_id: int = None) -> List[ return results.all() +async def get_filing_periods(session: AsyncSession) -> List[FilingPeriodDAO]: + async with session.begin(): + stmt = select(FilingPeriodDAO) + results = await session.scalars(stmt) + return results.all() + + async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO: return await query_helper(session, submission_id, SubmissionDAO) @@ -49,6 +61,19 @@ async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> Su return new_sub +async def update_submission(submission: SubmissionDAO, incoming_session: AsyncSession = None) -> SubmissionDAO: + session = incoming_session if incoming_session else await anext(get_session()) + async with session.begin(): + try: + new_sub = await session.merge(submission) + await session.commit() + return new_sub + except Exception as e: + await session.rollback() + logger.error(f"There was an exception storing the updated SubmissionDAO, rolling back transaction: {e}") + raise + + async def upsert_filing_period(session: AsyncSession, filing_period: FilingPeriodDTO) -> FilingPeriodDAO: return await upsert_helper(session, filing_period, FilingPeriodDAO) diff --git a/src/routers/filing.py b/src/routers/filing.py index e8fa27f8..3f3bb363 100644 --- a/src/routers/filing.py +++ b/src/routers/filing.py @@ -5,7 +5,7 @@ from typing import Annotated, List from entities.engine import get_session -from entities.models import SubmissionDTO +from entities.models import FilingPeriodDTO, SubmissionDTO from entities.repos import submission_repo as repo from sqlalchemy.ext.asyncio import AsyncSession @@ -18,6 +18,11 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_ router = Router(dependencies=[Depends(set_db)]) +@router.get("/periods", response_model=List[FilingPeriodDTO]) +async def get_filing_periods(request: Request): + return await repo.get_filing_periods(request.state.db_session) + + @router.post("/{lei}/submissions/{submission_id}", status_code=HTTPStatus.ACCEPTED) async def upload_file( request: Request, lei: str, submission_id: str, file: UploadFile, background_tasks: BackgroundTasks @@ -28,5 +33,5 @@ async def upload_file( @router.get("/{lei}/filings/{filing_id}/submissions", response_model=List[SubmissionDTO]) -async def get_filing_periods(request: Request, lei: str, filing_id: int): +async def get_submission(request: Request, lei: str, filing_id: int): return await repo.get_submissions(request.state.db_session, filing_id) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 4c594715..4ead5458 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -1,7 +1,11 @@ import pytest +from datetime import datetime from fastapi import FastAPI from pytest_mock import MockerFixture +from unittest.mock import Mock + +from entities.models import FilingPeriodDAO, FilingType @pytest.fixture @@ -9,3 +13,18 @@ def app_fixture(mocker: MockerFixture) -> FastAPI: from main import app return app + + +@pytest.fixture +def get_filing_period_mock(mocker: MockerFixture) -> Mock: + mock = mocker.patch("entities.repos.submission_repo.get_filing_periods") + mock.return_value = [ + FilingPeriodDAO( + name="FilingPeriod2024", + start_period=datetime.now(), + end_period=datetime.now(), + due=datetime.now(), + filing_type=FilingType.MANUAL, + ) + ] + return mock diff --git a/tests/api/routers/test_filing_api.py b/tests/api/routers/test_filing_api.py index 9d4c7ac1..4b622778 100644 --- a/tests/api/routers/test_filing_api.py +++ b/tests/api/routers/test_filing_api.py @@ -1,4 +1,4 @@ -from unittest.mock import ANY +from unittest.mock import ANY, Mock from fastapi import FastAPI from fastapi.testclient import TestClient @@ -8,6 +8,13 @@ class TestFilingApi: + def test_get_periods(self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock): + client = TestClient(app_fixture) + res = client.get("/v1/filing/periods") + assert res.status_code == 200 + assert len(res.json()) == 1 + assert res.json()[0]["name"] == "FilingPeriod2024" + async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI): mock = mocker.patch("entities.repos.submission_repo.get_submissions") mock.return_value = [ diff --git a/tests/entities/repos/test_submission_repo.py b/tests/entities/repos/test_submission_repo.py index 0f6d7ca3..8a75b022 100644 --- a/tests/entities/repos/test_submission_repo.py +++ b/tests/entities/repos/test_submission_repo.py @@ -3,7 +3,7 @@ from datetime import datetime from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session from entities.models import ( SubmissionDAO, @@ -17,14 +17,18 @@ SubmissionState, ) from entities.repos import submission_repo as repo +from pytest_mock import MockerFixture + +from entities.engine import engine as entities_engine class TestSubmissionRepo: @pytest.fixture(scope="function", autouse=True) async def setup( - self, - transaction_session: AsyncSession, + self, transaction_session: AsyncSession, mocker: MockerFixture, session_generator: async_scoped_session ): + mocker.patch.object(entities_engine, "SessionLocal", return_value=session_generator) + filing_period = FilingPeriodDAO( name="FilingPeriod2024", start_period=datetime.now(), @@ -85,6 +89,11 @@ async def test_add_filing_period(self, transaction_session: AsyncSession): assert res.id == 2 assert res.filing_type == FilingType.MANUAL + async def test_get_filing_periods(self, query_session: AsyncSession): + res = await repo.get_filing_periods(query_session) + assert len(res) == 1 + assert res[0].name == "FilingPeriod2024" + async def test_get_filing_period(self, query_session: AsyncSession): res = await repo.get_filing_period(query_session, filing_period_id=1) assert res.id == 1 @@ -158,27 +167,40 @@ async def test_add_submission(self, transaction_session: AsyncSession): assert res.filing == 1 assert res.state == SubmissionState.SUBMISSION_UPLOADED - async def test_update_submission(self, transaction_session: AsyncSession): - res = await repo.add_submission(transaction_session, SubmissionDTO(submitter="test2@cfpb.gov", filing=2)) + async def test_update_submission(self, session_generator: async_scoped_session): + async with session_generator() as add_session: + res = await repo.add_submission(add_session, SubmissionDTO(submitter="test2@cfpb.gov", filing=2)) + res.state = SubmissionState.VALIDATION_IN_PROGRESS + res = await repo.update_submission(res) - stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) - new_res1 = await transaction_session.scalar(stmt) - assert new_res1.id == 4 - assert new_res1.filing == 2 - assert new_res1.state == SubmissionState.VALIDATION_IN_PROGRESS + async def query_updated_dao(): + async with session_generator() as search_session: + stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) + new_res1 = await search_session.scalar(stmt) + assert new_res1.id == 4 + assert new_res1.filing == 2 + assert new_res1.state == SubmissionState.VALIDATION_IN_PROGRESS + + await query_updated_dao() validation_json = self.get_error_json() res.validation_json = validation_json res.state = SubmissionState.VALIDATION_WITH_ERRORS - - stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) - new_res2 = await transaction_session.scalar(stmt) - - assert new_res2.id == 4 - assert new_res2.filing == 2 - assert new_res2.state == SubmissionState.VALIDATION_WITH_ERRORS - assert new_res2.validation_json == validation_json + # to test passing in a session to the update_submission function + async with session_generator() as update_session: + res = await repo.update_submission(res, update_session) + + async def query_updated_dao(): + async with session_generator() as search_session: + stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4) + new_res2 = await search_session.scalar(stmt) + assert new_res2.id == 4 + assert new_res2.filing == 2 + assert new_res2.state == SubmissionState.VALIDATION_WITH_ERRORS + assert new_res2.validation_json == validation_json + + await query_updated_dao() def get_error_json(self): df_columns = [