
Commit

Updated to try and shrink down the memory footprint of the dataframes
jcadam14 committed Sep 26, 2024
1 parent 37b03d3 commit 452fbd9
Showing 6 changed files with 160 additions and 547 deletions.
521 changes: 30 additions & 491 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -19,9 +19,9 @@ tabulate = "^0.9.0"
ujson = "^5.9.0"
matplotlib = "^3.9.0"
fsspec = "^2024.6.1"
s3fs = "^2024.6.1"
polars = "^1.6.0"
pyarrow = "^17.0.0"
boto3 = "^1.35.27"

[tool.poetry.group.dev.dependencies]
pytest = "8.3.2"
4 changes: 4 additions & 0 deletions src/regtech_data_validator/cli.py
@@ -83,6 +83,8 @@ def validate(
"""
context_dict = {x.key: x.value for x in context} if context else {}

from datetime import datetime
start = datetime.now()
total_findings = 0
final_phase = ValidationPhase.LOGICAL
all_findings = []
@@ -99,6 +101,7 @@

if all_findings:
final_df = pl.concat(all_findings, how="diagonal")
final_df = final_df.with_columns(phase=pl.lit(final_phase.value))

status = "SUCCESS" if total_findings == 0 else "FAILURE"

@@ -113,6 +116,7 @@
print(df_to_table(final_df))
case OutputFormat.DOWNLOAD:
df_to_download(final_df)
print(f"Took {(datetime.now() - start).total_seconds()} seconds")
case _:
raise ValueError(f'output format "{output}" not supported')

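The cli.py changes above combine the per-phase findings frames with a "diagonal" concat and stamp a single phase column onto the combined result. A minimal, standalone sketch of that pattern — the column names and phase value below are made up for illustration, not the validator's real schema:

import polars as pl

# Findings frames from different checks can carry different field_/value_ columns;
# a "diagonal" concat unions the columns and null-fills the gaps.
findings_a = pl.DataFrame({"validation_id": ["E0001"], "field_1": ["uid"], "value_1": ["BADUID"]})
findings_b = pl.DataFrame(
    {"validation_id": ["W0003"], "field_1": ["ct_credit_product"], "value_1": ["977"], "field_2": ["ct_credit_product_ff"], "value_2": [""]}
)

final_df = pl.concat([findings_a, findings_b], how="diagonal")

# Tag every finding with the phase the run finished in, as a literal column.
final_df = final_df.with_columns(phase=pl.lit("Logical"))
print(final_df)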
112 changes: 88 additions & 24 deletions src/regtech_data_validator/data_formatters.py
@@ -9,12 +9,32 @@

from io import BytesIO

from regtech_data_validator.checks import SBLCheck
from regtech_data_validator.validation_results import ValidationPhase
from regtech_data_validator.phase_validations import (
get_phase_1_schema_for_lei,
get_phase_2_schema_for_lei,
get_register_schema,
)


def find_check(group_name, checks):
gen = (check for check in checks if check.title == group_name)
return next(gen)


def get_checks(phase):
if phase == ValidationPhase.SYNTACTICAL:
syntax_schema = get_phase_1_schema_for_lei()
checks = [check for col_schema in syntax_schema.columns.values() for check in col_schema.checks]
else:
logic_schema = get_phase_2_schema_for_lei()
checks = [check for col_schema in logic_schema.columns.values() for check in col_schema.checks]
register_schema = get_register_schema()
checks.extend([check for col_schema in register_schema.columns.values() for check in col_schema.checks])
return checks


# Takes the error dataframe, which is a bit obscure, and translates it to a format of:
# validation_type, validation_id, validation_name, row, unique_identifier, fig_link, validation_description, scope, field_#, value_#
# which corresponds to severity, error/warning code, name of error/warning, row number in sblar, UID, fig link,
@@ -54,10 +74,10 @@ def format_findings(df: pl.DataFrame, checks):
df_pivot = df_pivot.with_columns(
validation_type=pl.lit(check.severity.value),
validation_id=pl.lit(validation_id),
validation_description=pl.lit(check.description),
validation_name=pl.lit(check.name),
fig_link=pl.lit(check.fig_link),
scope=pl.lit(check.scope),
#validation_description=pl.lit(check.description),
#validation_name=pl.lit(check.name),
#fig_link=pl.lit(check.fig_link),
#scope=pl.lit(check.scope),
).rename(
{
"record_no": "row",
@@ -83,21 +103,16 @@
[
"validation_type",
"validation_id",
"validation_name",
"row",
"unique_identifier",
"fig_link",
"validation_description",
"scope",
]
+ sorted_columns
)
final_df = pl.concat([final_df, df_pivot], how="diagonal")
print(f"Final DF: {final_df}")
return final_df


def df_to_download(df: pl.DataFrame, path: str = "download_report.csv"):
def df_to_download(df: pl.DataFrame, path: str = "download_report.csv", warning_count: int = 0, error_count: int = 0, max_errors: int = 1000000):
if df.is_empty():
# return headers of csv for 'empty' report
empty_df = pl.DataFrame(
@@ -115,20 +130,65 @@ def df_to_download(df: pl.DataFrame, path: str = "download_report.csv"):
empty_df.write_csv(f, quote_style='non_numeric')
return

sorted_df = (
df.with_columns(pl.col('validation_id').cast(pl.Categorical(ordering='lexical')))
# get the checks for the phase the results were in, so we can pull static data out of each found check
checks = get_checks(df.select(pl.first("phase")).item())

# Place the static check data into a dataframe, then join the results frame with it on validation_id.
# This is much faster than applying the fields to each row.
check_values = [{"validation_id": check.title, "validation_description":check.description, "validation_name":check.name, "fig_link":check.fig_link} for check in checks]
checks_df = pl.DataFrame(check_values)
joined_df = df.join(checks_df, on="validation_id")

# Sort by validation_id, order the field and value columns so they end up like field_1, value_1, field_2, value_2, ...
# and organize the columns as desired for the csv
joined_df = (
joined_df.with_columns(pl.col('validation_id').cast(pl.Categorical(ordering='lexical')))
.sort('validation_id')
.drop(["scope"])
)

field_columns = [col for col in joined_df.columns if col.startswith('field_')]
value_columns = [col for col in joined_df.columns if col.startswith('value_')]
sorted_columns = [col for pair in zip(field_columns, value_columns) for col in pair]

sorted_df = joined_df[
[
"validation_type",
"validation_id",
"validation_name",
"row",
"unique_identifier",
"fig_link",
"validation_description",
]
+ sorted_columns
]

buffer = BytesIO()
headers = ','.join(sorted_df.columns) + '\n'
buffer.write(headers.encode())


total_errors = warning_count + error_count
error_type = "errors"
if warning_count > 0:
if error_count > 0:
error_type = "errors and warnings"
else:
error_type = "warnings"

if total_errors and total_errors > max_errors:
buffer.write(f'"Your register contains {total_errors} {error_type}, however, only {max_errors} records are displayed in this report. To see additional {error_type}, correct the listed records, and upload a new file."\n'.encode())

if path.startswith("s3"):
buffer = BytesIO()
df.write_csv(buffer)
sorted_df.write_csv(buffer, quote_style='non_numeric', include_header=False)
buffer.seek(0)
upload(path, buffer.getvalue())
else:
with fsspec.open(path, mode='wb') as f:
sorted_df.write_csv(f, quote_style='non_numeric')
sorted_df.write_csv(buffer, quote_style='non_numeric', include_header=False)
buffer.seek(0)
f.write(buffer.getvalue())


def upload(path: str, content: bytes) -> None:
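The df_to_download change above swaps per-row population of check metadata for a small lookup frame that is joined onto the findings by validation_id. A rough sketch of that pattern, using made-up stand-in check objects rather than the project's SBLCheck instances:

import polars as pl
from dataclasses import dataclass

@dataclass
class FakeCheck:  # stand-in for SBLCheck; only the fields the report needs
    title: str
    name: str
    description: str
    fig_link: str

checks = [
    FakeCheck("E0001", "uid.invalid_text_length", "UID must be 21-45 characters.", "https://example.com/fig#E0001"),
    FakeCheck("W0003", "ct_credit_product.enum_value_conflict", "Free-form text conflict.", "https://example.com/fig#W0003"),
]

findings = pl.DataFrame({
    "validation_id": ["E0001", "W0003", "E0001"],
    "row": [2, 5, 9],
})

# One tiny frame holds the static per-check metadata...
checks_df = pl.DataFrame(
    [{"validation_id": c.title, "validation_name": c.name,
      "validation_description": c.description, "fig_link": c.fig_link} for c in checks]
)

# ...and a single join attaches it to every finding, instead of applying a
# Python function row by row (which is slower and allocates far more).
joined = findings.join(checks_df, on="validation_id")
print(joined)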
@@ -171,7 +231,10 @@ def df_to_dicts(df: pl.DataFrame, max_records: int = 10000, max_group_size: int
sorted_df = df.with_columns(pl.col('validation_id').cast(pl.Categorical(ordering='lexical'))).sort(
'validation_id'
)
partial_process_group = partial(process_group_data, json_results=json_results, group_size=max_group_size)

checks = get_checks(df.select(pl.first("phase")).item())

partial_process_group = partial(process_group_data, json_results=json_results, group_size=max_group_size, checks=checks)
# collecting just the currently processed group from a lazyframe is faster and more efficient than using "apply"
sorted_df.lazy().group_by('validation_id').map_groups(partial_process_group, schema=None).collect()
json_results = sorted(json_results, key=lambda x: x['validation']['id'])
@@ -187,17 +250,18 @@ def truncate_validation_group_records(group, group_size):
return truncated_group, need_to_truncate


def process_group_data(group_df, json_results, group_size):
def process_group_data(group_df, json_results, group_size, checks):
validation_id = group_df['validation_id'].item(0)
check = find_check(validation_id, checks)
trunc_group, need_to_truncate = truncate_validation_group_records(group_df, group_size)
group_json = process_chunk(trunc_group, validation_id)
group_json = process_chunk(trunc_group, validation_id, check)
if group_json:
group_json["validation"]["is_truncated"] = need_to_truncate
json_results.append(group_json)
return group_df


def process_chunk(df: pl.DataFrame, validation_id: str) -> [dict]:
def process_chunk(df: pl.DataFrame, validation_id: str, check: SBLCheck) -> [dict]:
# once we have a grouped dataframe, working with the data as a
# python dict is much faster
findings_json = ujson.loads(df.write_json())
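process_chunk converts the grouped frame to plain Python dicts via write_json before shaping the JSON output, on the theory that dict access is cheaper than repeated DataFrame indexing for this step. A minimal sketch of that conversion, assuming polars' row-oriented write_json() string output:

import ujson
import polars as pl

group = pl.DataFrame({
    "validation_id": ["E0001", "E0001"],
    "row": [2, 9],
    "field_1": ["uid", "uid"],
    "value_1": ["BAD1", "BAD2"],
})

# write_json() with no target returns a row-oriented JSON string,
# which ujson turns into a list of plain dicts.
records = ujson.loads(group.write_json())
for rec in records:
    print(rec["row"], rec["field_1"], rec["value_1"])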
@@ -218,11 +282,11 @@ def process_chunk(df: pl.DataFrame, validation_id: str) -> [dict]:
validation_info = {
'validation': {
'id': validation_id,
'name': first_finding['validation_name'],
'description': first_finding['validation_description'],
'name': check.name,
'description': check.description,
'severity': first_finding['validation_type'],
'scope': first_finding['scope'],
'fig_link': first_finding['fig_link'],
'scope': check.scope,
'fig_link': check.fig_link,
},
'records': records,
}
32 changes: 32 additions & 0 deletions src/regtech_data_validator/phase_validations.py
@@ -3,6 +3,7 @@
This mapping is used to populate the schema template object and create
an instance of a PanderaSchema object for SYNTACTICAL and LOGICAL phases"""
import pandera.polars as pa

from textwrap import dedent

@@ -31,8 +32,39 @@
string_contains,
)
from regtech_data_validator.checks import SBLCheck, Severity
from regtech_data_validator.schema_template import get_template, get_register_template
from regtech_data_validator.validation_results import ValidationPhase

# Get separate schema templates for phase 1 and 2
phase_1_template = get_template()
phase_2_template = get_template()
register_template = get_register_template()


def get_schema_by_phase_for_lei(template: dict, phase: str, context: dict[str, str] | None = None):
for column in get_phase_1_and_2_validations_for_lei(context):
validations = get_phase_1_and_2_validations_for_lei(context)[column]
template[column].checks = validations[phase]

return pa.DataFrameSchema(template, name=phase)


def get_phase_1_schema_for_lei(context: dict[str, str] | None = None):
return get_schema_by_phase_for_lei(phase_1_template, ValidationPhase.SYNTACTICAL, context)


def get_phase_2_schema_for_lei(context: dict[str, str] | None = None):
return get_schema_by_phase_for_lei(phase_2_template, ValidationPhase.LOGICAL, context)

# since we process the data in chunks/batch, we need to handle all file/register
# checks separately, as a separate set of schema and checks.
def get_register_schema(context: dict[str, str] | None = None):
for column in get_phase_2_register_validations(context):
validations = get_phase_2_register_validations(context)[column]
register_template[column].checks = validations[ValidationPhase.LOGICAL]

return pa.DataFrameSchema(register_template, name=ValidationPhase.LOGICAL)


# since we process the data in chunks/batch, we need to handle all file/register
# checks separately, as a separate set of schema and checks.
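With the schema builders relocated into phase_validations, a caller can construct the phase schemas directly from this module. A hypothetical usage sketch — the "lei" context key and the lazy validate/except flow below are illustrative assumptions, not lifted from the repo's validator.py:

import polars as pl
from pandera.errors import SchemaErrors

from regtech_data_validator.phase_validations import (
    get_phase_1_schema_for_lei,
    get_phase_2_schema_for_lei,
    get_register_schema,
)

df = pl.DataFrame({"uid": ["123456789TESTBANK123C1"]})  # toy register slice

syntax_schema = get_phase_1_schema_for_lei({"lei": "123456789TESTBANK123"})
try:
    syntax_schema.validate(df, lazy=True)
except SchemaErrors as errs:
    print("syntactical findings:", errs.failure_cases)
else:
    # Only build and run the logical and register-level schemas once syntax passes.
    logic_schema = get_phase_2_schema_for_lei({"lei": "123456789TESTBANK123"})
    register_schema = get_register_schema()
    logic_schema.validate(df, lazy=True)
    register_schema.validate(df, lazy=True)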
36 changes: 5 additions & 31 deletions src/regtech_data_validator/validator.py
@@ -21,37 +21,11 @@
import shutil
import os

# Get separate schema templates for phase 1 and 2
phase_1_template = get_template()
phase_2_template = get_template()
register_template = get_register_template()


def get_schema_by_phase_for_lei(template: dict, phase: str, context: dict[str, str] | None = None):
for column in get_phase_1_and_2_validations_for_lei(context):
validations = get_phase_1_and_2_validations_for_lei(context)[column]
template[column].checks = validations[phase]

return pa.DataFrameSchema(template, name=phase)


def get_phase_1_schema_for_lei(context: dict[str, str] | None = None):
return get_schema_by_phase_for_lei(phase_1_template, ValidationPhase.SYNTACTICAL, context)


def get_phase_2_schema_for_lei(context: dict[str, str] | None = None):
return get_schema_by_phase_for_lei(phase_2_template, ValidationPhase.LOGICAL, context)


# since we process the data in chunks/batch, we need to handle all file/register
# checks separately, as a separate set of schema and checks.
def get_register_schema(context: dict[str, str] | None = None):
for column in get_phase_2_register_validations(context):
validations = get_phase_2_register_validations(context)[column]
register_template[column].checks = validations[ValidationPhase.LOGICAL]

return pa.DataFrameSchema(register_template, name=ValidationPhase.LOGICAL)

from regtech_data_validator.phase_validations import (
get_phase_1_schema_for_lei,
get_phase_2_schema_for_lei,
get_register_schema,
)

# Gets all associated field names from the check
def _get_check_fields(check: Check, primary_column: str) -> list[str]:

