Skip to content

Commit

Permalink
Merge branch 'main' into 42-add-get-submission-endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jcadam14 committed Jan 31, 2024
2 parents 6696dc0 + 906df07 commit a11c2c7
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 21 deletions.
25 changes: 25 additions & 0 deletions src/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, List, TypeVar
from entities.engine import get_session

from entities.models import (
SubmissionDAO,
Expand All @@ -12,6 +15,8 @@
FilingDAO,
)

logger = logging.getLogger(__name__)

T = TypeVar("T")


Expand All @@ -24,6 +29,13 @@ async def get_submissions(session: AsyncSession, filing_id: int = None) -> List[
return results.all()


async def get_filing_periods(session: AsyncSession) -> List[FilingPeriodDAO]:
async with session.begin():
stmt = select(FilingPeriodDAO)
results = await session.scalars(stmt)
return results.all()


async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO:
return await query_helper(session, submission_id, SubmissionDAO)

Expand All @@ -49,6 +61,19 @@ async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> Su
return new_sub


async def update_submission(submission: SubmissionDAO, incoming_session: AsyncSession = None) -> SubmissionDAO:
session = incoming_session if incoming_session else await anext(get_session())
async with session.begin():
try:
new_sub = await session.merge(submission)
await session.commit()
return new_sub
except Exception as e:
await session.rollback()
logger.error(f"There was an exception storing the updated SubmissionDAO, rolling back transaction: {e}")
raise


async def upsert_filing_period(session: AsyncSession, filing_period: FilingPeriodDTO) -> FilingPeriodDAO:
return await upsert_helper(session, filing_period, FilingPeriodDAO)

Expand Down
9 changes: 7 additions & 2 deletions src/routers/filing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, List

from entities.engine import get_session
from entities.models import SubmissionDTO
from entities.models import FilingPeriodDTO, SubmissionDTO
from entities.repos import submission_repo as repo

from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -18,6 +18,11 @@ async def set_db(request: Request, session: Annotated[AsyncSession, Depends(get_
router = Router(dependencies=[Depends(set_db)])


@router.get("/periods", response_model=List[FilingPeriodDTO])
async def get_filing_periods(request: Request):
return await repo.get_filing_periods(request.state.db_session)


@router.post("/{lei}/submissions/{submission_id}", status_code=HTTPStatus.ACCEPTED)
async def upload_file(
request: Request, lei: str, submission_id: str, file: UploadFile, background_tasks: BackgroundTasks
Expand All @@ -28,5 +33,5 @@ async def upload_file(


@router.get("/{lei}/filings/{filing_id}/submissions", response_model=List[SubmissionDTO])
async def get_filing_periods(request: Request, lei: str, filing_id: int):
async def get_submission(request: Request, lei: str, filing_id: int):
return await repo.get_submissions(request.state.db_session, filing_id)
19 changes: 19 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
import pytest

from datetime import datetime
from fastapi import FastAPI
from pytest_mock import MockerFixture
from unittest.mock import Mock

from entities.models import FilingPeriodDAO, FilingType


@pytest.fixture
def app_fixture(mocker: MockerFixture) -> FastAPI:
from main import app

return app


@pytest.fixture
def get_filing_period_mock(mocker: MockerFixture) -> Mock:
mock = mocker.patch("entities.repos.submission_repo.get_filing_periods")
mock.return_value = [
FilingPeriodDAO(
name="FilingPeriod2024",
start_period=datetime.now(),
end_period=datetime.now(),
due=datetime.now(),
filing_type=FilingType.MANUAL,
)
]
return mock
9 changes: 8 additions & 1 deletion tests/api/routers/test_filing_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import ANY
from unittest.mock import ANY, Mock

from fastapi import FastAPI
from fastapi.testclient import TestClient
Expand All @@ -8,6 +8,13 @@


class TestFilingApi:
def test_get_periods(self, mocker: MockerFixture, app_fixture: FastAPI, get_filing_period_mock: Mock):
client = TestClient(app_fixture)
res = client.get("/v1/filing/periods")
assert res.status_code == 200
assert len(res.json()) == 1
assert res.json()[0]["name"] == "FilingPeriod2024"

async def test_get_submissions(self, mocker: MockerFixture, app_fixture: FastAPI):
mock = mocker.patch("entities.repos.submission_repo.get_submissions")
mock.return_value = [
Expand Down
58 changes: 40 additions & 18 deletions tests/entities/repos/test_submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from datetime import datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session

from entities.models import (
SubmissionDAO,
Expand All @@ -17,14 +17,18 @@
SubmissionState,
)
from entities.repos import submission_repo as repo
from pytest_mock import MockerFixture

from entities.engine import engine as entities_engine


class TestSubmissionRepo:
@pytest.fixture(scope="function", autouse=True)
async def setup(
self,
transaction_session: AsyncSession,
self, transaction_session: AsyncSession, mocker: MockerFixture, session_generator: async_scoped_session
):
mocker.patch.object(entities_engine, "SessionLocal", return_value=session_generator)

filing_period = FilingPeriodDAO(
name="FilingPeriod2024",
start_period=datetime.now(),
Expand Down Expand Up @@ -85,6 +89,11 @@ async def test_add_filing_period(self, transaction_session: AsyncSession):
assert res.id == 2
assert res.filing_type == FilingType.MANUAL

async def test_get_filing_periods(self, query_session: AsyncSession):
res = await repo.get_filing_periods(query_session)
assert len(res) == 1
assert res[0].name == "FilingPeriod2024"

async def test_get_filing_period(self, query_session: AsyncSession):
res = await repo.get_filing_period(query_session, filing_period_id=1)
assert res.id == 1
Expand Down Expand Up @@ -158,27 +167,40 @@ async def test_add_submission(self, transaction_session: AsyncSession):
assert res.filing == 1
assert res.state == SubmissionState.SUBMISSION_UPLOADED

async def test_update_submission(self, transaction_session: AsyncSession):
res = await repo.add_submission(transaction_session, SubmissionDTO(submitter="[email protected]", filing=2))
async def test_update_submission(self, session_generator: async_scoped_session):
async with session_generator() as add_session:
res = await repo.add_submission(add_session, SubmissionDTO(submitter="[email protected]", filing=2))

res.state = SubmissionState.VALIDATION_IN_PROGRESS
res = await repo.update_submission(res)

stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4)
new_res1 = await transaction_session.scalar(stmt)
assert new_res1.id == 4
assert new_res1.filing == 2
assert new_res1.state == SubmissionState.VALIDATION_IN_PROGRESS
async def query_updated_dao():
async with session_generator() as search_session:
stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4)
new_res1 = await search_session.scalar(stmt)
assert new_res1.id == 4
assert new_res1.filing == 2
assert new_res1.state == SubmissionState.VALIDATION_IN_PROGRESS

await query_updated_dao()

validation_json = self.get_error_json()
res.validation_json = validation_json
res.state = SubmissionState.VALIDATION_WITH_ERRORS

stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4)
new_res2 = await transaction_session.scalar(stmt)

assert new_res2.id == 4
assert new_res2.filing == 2
assert new_res2.state == SubmissionState.VALIDATION_WITH_ERRORS
assert new_res2.validation_json == validation_json
# to test passing in a session to the update_submission function
async with session_generator() as update_session:
res = await repo.update_submission(res, update_session)

async def query_updated_dao():
async with session_generator() as search_session:
stmt = select(SubmissionDAO).filter(SubmissionDAO.id == 4)
new_res2 = await search_session.scalar(stmt)
assert new_res2.id == 4
assert new_res2.filing == 2
assert new_res2.state == SubmissionState.VALIDATION_WITH_ERRORS
assert new_res2.validation_json == validation_json

await query_updated_dao()

def get_error_json(self):
df_columns = [
Expand Down

0 comments on commit a11c2c7

Please sign in to comment.