Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

21 add update submission function to submission repo #25

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
asyncio_mode = "auto"
pythonpath = ["src"]
env = [
"INST_DB_SCHEMA=main",
"INST_DB_USER=user",
"INST_DB_PWD=user",
"INST_DB_HOST=localhost",
"INST_DB_NAME=filing"
]
addopts = [
"--cov-report=term-missing",
"--cov-branch",
Expand Down
41 changes: 41 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from urllib import parse
from typing import Any

from pydantic import field_validator, ValidationInfo
from pydantic.networks import PostgresDsn
from pydantic_settings import BaseSettings, SettingsConfigDict

env_files_to_load = [".env"]
if os.getenv("ENV", "LOCAL") == "LOCAL":
env_files_to_load.append(".env.local")


class Settings(BaseSettings):
inst_db_schema: str = "public"
inst_db_name: str
inst_db_user: str
inst_db_pwd: str
inst_db_host: str
inst_db_scheme: str = "postgresql+asyncpg"
inst_conn: PostgresDsn | None = None

def __init__(self, **data):
super().__init__(**data)

@field_validator("inst_conn", mode="before")
@classmethod
def build_postgres_dsn(cls, postgres_dsn, info: ValidationInfo) -> Any:
postgres_dsn = PostgresDsn.build(
scheme=info.data.get("inst_db_scheme"),
username=info.data.get("inst_db_user"),
password=parse.quote(info.data.get("inst_db_pwd"), safe=""),
host=info.data.get("inst_db_host"),
path=info.data.get("inst_db_name"),
)
return str(postgres_dsn)

model_config = SettingsConfigDict(env_file=env_files_to_load, extra="allow")


settings = Settings()
3 changes: 3 additions & 0 deletions src/entities/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
__all__ = ["get_session"]

from .engine import get_session
20 changes: 20 additions & 0 deletions src/entities/engine/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
async_scoped_session,
)
from asyncio import current_task
from config import settings

engine = create_async_engine(settings.inst_conn.unicode_string(), echo=True).execution_options(
schema_translate_map={None: settings.inst_db_schema}
)
SessionLocal = async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)


async def get_session():
session = SessionLocal()
try:
yield session
finally:
await session.close()
16 changes: 16 additions & 0 deletions src/entities/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
__all__ = [
"Base",
"SubmissionDAO",
"SubmissionDTO",
"SubmissionState",
"FilingDAO",
"FilingDTO",
"FilingPeriodDAO",
"FilingPeriodDTO",
"FilingType",
"FilingState",
]

from .dao import Base, SubmissionDAO, FilingPeriodDAO, FilingDAO
from .dto import SubmissionDTO, FilingDTO, FilingPeriodDTO
from .model_enums import FilingType, FilingState, SubmissionState
68 changes: 68 additions & 0 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from .model_enums import FilingType, FilingState, SubmissionState
from datetime import datetime
from typing import Any
from sqlalchemy import Enum as SAEnum
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.types import JSON


class Base(AsyncAttrs, DeclarativeBase):
pass


class SubmissionDAO(Base):
__tablename__ = "submission"
id: Mapped[int] = mapped_column(index=True, primary_key=True, autoincrement=True)
submitter: Mapped[str]
state: Mapped[SubmissionState] = mapped_column(SAEnum(SubmissionState))
validation_ruleset_version: Mapped[str]
validation_json: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=True)
filing: Mapped[str] = mapped_column(ForeignKey("filing.id"))

def __str__(self):
return f"Submission ID: {self.id}, Submitter: {self.submitter}, State: {self.state}, Ruleset: {self.validation_ruleset_version}, Filing: {self.filing}"


class FilingPeriodDAO(Base):
__tablename__ = "filing_period"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
name: Mapped[str]
start_period: Mapped[datetime]
end_period: Mapped[datetime]
due: Mapped[datetime]
filing_type: Mapped[FilingType] = mapped_column(SAEnum(FilingType))


class FilingDAO(Base):
__tablename__ = "filing"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
lei: Mapped[str]
state: Mapped[FilingState] = mapped_column(SAEnum(FilingState))
filing_period: Mapped[int] = mapped_column(ForeignKey("filing_period.id"))
institution_snapshot_id = Mapped[str] # not sure what this is


# Commenting out for now since we're just storing the results from the data-validator as JSON.
# If we determine building the data structure for results as tables is needed, we can add these
# back in.
# class FindingDAO(Base):
# __tablename__ = "submission_finding"
# id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# submission_id: Mapped[str] = mapped_column(ForeignKey("submission.id"))
# submission: Mapped["SubmissionDAO"] = relationship(back_populates="results") # if we care about bidirectional
# validation_code: Mapped[str]
# severity: Mapped[Severity] = mapped_column(Enum(*get_args(Severity)))
# records: Mapped[List["RecordDAO"]] = relationship(back_populates="result")


# class RecordDAO(Base):
# __tablename__ = "submission_finding_record"
# id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
# result_id: Mapped[str] = mapped_column(ForeignKey("submission_finding.id"))
# result: Mapped["FindingDAO"] = relationship(back_populates="records") # if we care about bidirectional
# record: Mapped[int]
# field_name: Mapped[str]
# field_value: Mapped[str]
36 changes: 36 additions & 0 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from datetime import datetime
from typing import Dict, Any
from pydantic import BaseModel, ConfigDict
from .model_enums import FilingType, FilingState, SubmissionState


class SubmissionDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int | None = None
submitter: str
state: SubmissionState | None = None
validation_ruleset_version: str | None = None
validation_json: Dict[str, Any] | None = None
filing: int


class FilingDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int | None = None
lei: str
state: FilingState
filing_period: int
institution_snapshot_id: str


class FilingPeriodDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int | None = None
name: str
start_period: datetime
end_period: datetime
due: datetime
filing_type: FilingType
20 changes: 20 additions & 0 deletions src/entities/models/model_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum


class SubmissionState(str, Enum):
SUBMISSION_UPLOADED = "SUBMISSION_UPLOADED"
VALIDATION_IN_PROGRESS = "VALIDATION_IN_PROGRESS"
VALIDATION_WITH_ERRORS = "VALIDATION_WITH_ERRORS"
VALIDATION_WITH_WARNINGS = "VALIDATION_WITH_WARNINGS"
VALIDATION_SUCCESSFUL = "VALIDATION_SUCCESSFUL"
SUBMISSION_SIGNED = "SUBMISSION_SIGNED"


class FilingState(str, Enum):
FILING_STARTED = "FILING_STARTED"
FILING_IN_PROGRESS = "FILING_IN_PROGRESS"
FILING_COMPLETE = "FILING_COMPLETE"


class FilingType(str, Enum):
MANUAL = "MANUAL"
Empty file added src/entities/repos/__init__.py
Empty file.
96 changes: 96 additions & 0 deletions src/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import logging

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, List, TypeVar
from entities.engine import get_session


from entities.models import (
SubmissionDAO,
SubmissionDTO,
SubmissionState,
FilingPeriodDAO,
FilingPeriodDTO,
FilingDTO,
FilingDAO,
)

T = TypeVar("T")

logger = logging.getLogger(__name__)


async def get_submissions(session: AsyncSession, filing_id: int = None) -> List[SubmissionDAO]:
async with session.begin():
stmt = select(SubmissionDAO)
if filing_id:
stmt = stmt.filter(SubmissionDAO.filing == filing_id)
results = await session.scalars(stmt)
return results.all()


async def get_submission(session: AsyncSession, submission_id: int) -> SubmissionDAO:
return await query_helper(session, submission_id, SubmissionDAO)


async def get_filing(session: AsyncSession, filing_id: int) -> FilingDAO:
return await query_helper(session, filing_id, FilingDAO)


async def get_filing_period(session: AsyncSession, filing_period_id: int) -> FilingPeriodDAO:
return await query_helper(session, filing_period_id, FilingPeriodDAO)


async def add_submission(session: AsyncSession, submission: SubmissionDTO) -> SubmissionDAO:
async with session.begin():
new_sub = SubmissionDAO(
filing=submission.filing,
submitter=submission.submitter,
state=SubmissionState.SUBMISSION_UPLOADED,
validation_ruleset_version="v1",
)
# this returns the attached object, most importantly with the new submission id
new_sub = await session.merge(new_sub)
await session.commit()
return new_sub


async def update_submission(submission: SubmissionDAO, incoming_session: AsyncSession = None) -> SubmissionDAO:
session = incoming_session if incoming_session else await anext(get_session())
async with session.begin():
try:
new_sub = await session.merge(submission)
await session.commit()
return new_sub
except Exception as e:
await session.rollback()
logger.error(f"There was an exception storing the updated SubmissionDAO, rolling back transaction: {e}")
raise


async def upsert_filing_period(session: AsyncSession, filing_period: FilingPeriodDTO) -> FilingPeriodDAO:
return await upsert_helper(session, filing_period, FilingPeriodDAO)


async def upsert_filing(session: AsyncSession, filing: FilingDTO) -> FilingDAO:
return await upsert_helper(session, filing, FilingDAO)


async def upsert_helper(session: AsyncSession, original_data: Any, type: T) -> T:
async with session.begin():
copy_data = original_data.__dict__.copy()
# this is only for if a DAO is passed in
# Should be DTOs, but hey, it's python
if copy_data["id"] is not None and "_sa_instance_state" in copy_data:
del copy_data["_sa_instance_state"]
new_dao = type(**copy_data)
new_dao = await session.merge(new_dao)
await session.commit()
return new_dao


async def query_helper(session: AsyncSession, id: int, type: T) -> T:
async with session.begin():
stmt = select(type).filter(type.id == id)
return await session.scalar(stmt)
61 changes: 61 additions & 0 deletions tests/entities/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio
import pytest

from asyncio import current_task
from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncEngine,
async_scoped_session,
async_sessionmaker,
)
from entities.models import Base


@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop()
try:
yield loop
finally:
loop.close()


@pytest.fixture(scope="session")
def engine():
return create_async_engine("sqlite+aiosqlite://")


@pytest.fixture(scope="function", autouse=True)
async def setup_db(
request: pytest.FixtureRequest,
engine: AsyncEngine,
event_loop: asyncio.AbstractEventLoop,
):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

def teardown():
async def td():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)

event_loop.run_until_complete(td())

request.addfinalizer(teardown)


@pytest.fixture(scope="function")
async def transaction_session(session_generator: async_scoped_session):
async with session_generator() as session:
yield session


@pytest.fixture(scope="function")
async def query_session(session_generator: async_scoped_session):
async with session_generator() as session:
yield session


@pytest.fixture(scope="function")
def session_generator(engine: AsyncEngine):
return async_scoped_session(async_sessionmaker(engine, expire_on_commit=False), current_task)
Loading
Loading