Skip to content

Commit

Permalink
Fixed Ruff and black linting
Browse files Browse the repository at this point in the history
  • Loading branch information
jcadam14 committed Dec 12, 2023
1 parent d6c5b6b commit 8b8a954
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 53 deletions.
20 changes: 12 additions & 8 deletions src/entities/models/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,42 @@

Severity = Literal["error", "warning"]


class Base(AsyncAttrs, DeclarativeBase):
pass


class AuditMixin(object):
event_time: Mapped[datetime] = mapped_column(server_default=func.now())



class SubmissionDAO(AuditMixin, Base):
__tablename__ = "submission"
submission_id: Mapped[str] = mapped_column(index=True, primary_key=True)
submitter: Mapped[str]
lei: Mapped[str]
results: Mapped[List["ValidationResultDAO"]] = relationship(back_populates="submission")
json_dump: Mapped[dict[str,Any]] = mapped_column(JSON, nullable=True)
json_dump: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=True)

def __str__(self):
return f"Submission ID: {self.submission_id}, Submitter: {self.submitter}, LEI: {self.lei}"


class ValidationResultDAO(AuditMixin, Base):
__tablename__ = "validation_results"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
submission_id: Mapped[str] = mapped_column(ForeignKey("submission.submission_id"))
submission: Mapped["SubmissionDAO"] = relationship(back_populates="results") # if we care about bidirectional
submission: Mapped["SubmissionDAO"] = relationship(back_populates="results") # if we care about bidirectional
validation_id: Mapped[str]
field_name: Mapped[str]
severity: Mapped[Severity] = mapped_column(Enum(*get_args(Severity)))
records: Mapped[List["RecordDAO"]] = relationship(back_populates="result")


class RecordDAO(AuditMixin, Base):
__tablename__ = "validation_result_record"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
result_id: Mapped[str] = mapped_column(ForeignKey("validation_results.id"))
result: Mapped["ValidationResultDAO"] = relationship(back_populates="records") # if we care about bidirectional
result: Mapped["ValidationResultDAO"] = relationship(back_populates="records") # if we care about bidirectional
record: Mapped[int]
data: Mapped[str]
data: Mapped[str]
11 changes: 6 additions & 5 deletions src/entities/models/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,26 @@
from pydantic import BaseModel, ConfigDict
from starlette.authentication import BaseUser


class RecordDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

record: int
data: str


class ValidationResultDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

validation_id: str
field_name: str
severity: str
records: List[RecordDTO] = []


class SubmissionDTO(BaseModel):
model_config = ConfigDict(from_attributes=True)

submission_id: str
lei: str
submitter: str
Expand Down
37 changes: 15 additions & 22 deletions src/entities/repos/submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,40 @@
import pandas as pd
import json

from entities.models import (
SubmissionDAO,
ValidationResultDAO,
RecordDAO
)

async def get_submission(
session: AsyncSession,
submission_id: str
) -> SubmissionDAO:
from entities.models import SubmissionDAO, ValidationResultDAO, RecordDAO


async def get_submission(session: AsyncSession, submission_id: str) -> SubmissionDAO:
async with session.begin():
stmt = (
select(SubmissionDAO)
.options(joinedload(SubmissionDAO.results).joinedload(ValidationResultDAO.records))
.filter(SubmissionDAO.submission_id == submission_id)
)
)
return await session.scalar(stmt)


async def add_submission(
session: AsyncSession,
submission_id: str,
submitter: str,
lei: str,
results: pd.DataFrame
session: AsyncSession, submission_id: str, submitter: str, lei: str, results: pd.DataFrame
) -> SubmissionDAO:
async with session.begin():
findings_by_v_id_df = results.reset_index().set_index(['validation_id'])
findings_by_v_id_df = results.reset_index().set_index(["validation_id"])
submission = SubmissionDAO(submission_id=submission_id, submitter=submitter, lei=lei)
validation_results = []
for v_id_idx, v_id_df in findings_by_v_id_df.groupby(by='validation_id'):
for v_id_idx, v_id_df in findings_by_v_id_df.groupby(by="validation_id"):
v_head = v_id_df.iloc[0]
print(f"Building results for error code {v_id_idx}")
result = ValidationResultDAO(validation_id=v_id_idx, field_name=v_head.at['field_name'], severity=v_head.at['validation_severity'])
result = ValidationResultDAO(
validation_id=v_id_idx, field_name=v_head.at["field_name"], severity=v_head.at["validation_severity"]
)
records = []
for rec_no, rec_df in v_id_df.iterrows():
print(f'{rec_no} Rec Def: {rec_df}')
record = RecordDAO(record=rec_df.at['record_no'], data=rec_df.at['field_value'])
print(f"{rec_no} Rec Def: {rec_df}")
record = RecordDAO(record=rec_df.at["record_no"], data=rec_df.at["field_value"])
records.append(record)
result.records = records
validation_results.append(result)
submission.results = validation_results
session.add(submission)

return submission
return submission
2 changes: 1 addition & 1 deletion tests/entities/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def setup_db(
):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

def teardown():
async def td():
async with engine.begin() as conn:
Expand Down
67 changes: 50 additions & 17 deletions tests/entities/repos/test_submission_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine
from sqlalchemy import select, func

from entities.models import (
SubmissionDAO,
ValidationResultDAO,
RecordDAO
)
from entities.models import SubmissionDAO, ValidationResultDAO, RecordDAO
from entities.repos import submission_repo as repo


Expand All @@ -17,13 +13,11 @@ async def setup(
self,
transaction_session: AsyncSession,
):
submission = SubmissionDAO(submission_id="12345",
submitter="[email protected]",
lei="1234567890ABCDEFGHIJ")
submission = SubmissionDAO(submission_id="12345", submitter="[email protected]", lei="1234567890ABCDEFGHIJ")
results = []
result1 = ValidationResultDAO(validation_id="E0123", field_name="uid", severity="error")
records = []
record1a = RecordDAO(record=1,data="empty")
record1a = RecordDAO(record=1, data="empty")
records.append(record1a)
result1.records = records
results.append(result1)
Expand All @@ -41,22 +35,61 @@ async def test_get_submission(self, query_session: AsyncSession):
assert len(res.results[0].records) == 1
assert res.results[0].validation_id == "E0123"
assert res.results[0].records[0].data == "empty"

async def test_add_submission(self, transaction_session: AsyncSession):
df_columns = ["record_no", "field_name", "field_value", "validation_severity", "validation_id", "validation_name", "validation_desc"]
df_data = [[0, "uid", "BADUID0", "error", "E0001", "id.invalid_text_length", "'Unique identifier' must be at least 21 characters in length."],
[0, "uid", "BADTEXTLENGTH", "error", "E0100", "ct_credit_product_ff.invalid_text_length", "'Free-form text field for other credit products' must not exceed 300 characters in length."],
[1, "uid", "BADUID1", "error", "E0001", "id.invalid_text_length", "'Unique identifier' must be at least 21 characters in length."]]
df_columns = [
"record_no",
"field_name",
"field_value",
"validation_severity",
"validation_id",
"validation_name",
"validation_desc",
]
df_data = [
[
0,
"uid",
"BADUID0",
"error",
"E0001",
"id.invalid_text_length",
"'Unique identifier' must be at least 21 characters in length.",
],
[
0,
"uid",
"BADTEXTLENGTH",
"error",
"E0100",
"ct_credit_product_ff.invalid_text_length",
"'Free-form text field for other credit products' must not exceed 300 characters in length.",
],
[
1,
"uid",
"BADUID1",
"error",
"E0001",
"id.invalid_text_length",
"'Unique identifier' must be at least 21 characters in length.",
],
]
error_df = pd.DataFrame(df_data, columns=df_columns)
print(f"Data Frame: {error_df}")
res = await repo.add_submission(transaction_session, submission_id="12346", submitter="[email protected]", lei="1234567890ABCDEFGHIJ", results=error_df)
res = await repo.add_submission(
transaction_session,
submission_id="12346",
submitter="[email protected]",
lei="1234567890ABCDEFGHIJ",
results=error_df,
)
assert res.submission_id == "12346"
assert res.submitter == "[email protected]"
assert res.lei == "1234567890ABCDEFGHIJ"
assert len(res.results) == 2 # Two error codes, 3 records total
assert len(res.results) == 2 # Two error codes, 3 records total
assert len(res.results[0].records) == 2
assert len(res.results[1].records) == 1
assert res.results[0].validation_id == "E0001"
assert res.results[1].validation_id == "E0100"
assert res.results[0].records[0].data == "BADUID0"

0 comments on commit 8b8a954

Please sign in to comment.