validate lazy frame
lchen-2101 committed Nov 8, 2024
1 parent 4398f7e commit 402064c
Showing 1 changed file with 84 additions and 23 deletions.
107 changes: 84 additions & 23 deletions src/regtech_data_validator/validator.py
@@ -2,10 +2,12 @@
 with validations listed in phase 1 and phase 2."""
 
 from pathlib import Path
+from typing import Dict, List
 import polars as pl
 import pandera.polars as pa
 from pandera import Check
 from pandera.errors import SchemaErrors, SchemaError, SchemaErrorReason
+from polars.io.csv.batched_reader import BatchedCsvReader
 
 from regtech_data_validator.checks import SBLCheck, Severity
 
@@ -69,8 +71,8 @@ def _add_validation_metadata(failed_check_fields_df: pl.DataFrame, check: SBLCheck
 
 
 def validate(
-    schema: pa.DataFrameSchema, submission_df: pl.LazyFrame, row_start: int, process_errors: bool
-) -> pl.DataFrame:
+    schema: pa.DataFrameSchema, submission_df: pl.DataFrame, row_start: int, process_errors: bool
+) -> ValidationResults:
     """
     validate received dataframe with schema and return list of
     schema errors
@@ -156,6 +158,41 @@ def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame) -> pl.DataFrame:
     return results_df
 
 
+def validate_lazy_frame(
+    lf: pl.LazyFrame,
+    context: dict[str, str] | None = None,
+    batch_size: int = 50000,
+    batch_count: int = 1,
+    max_errors=1000000,
+):
+
+    has_syntax_errors = False
+    syntax_schema = get_phase_1_schema_for_lei(context)
+    syntax_checks = [check for col_schema in syntax_schema.columns.values() for check in col_schema.checks]
+
+    logic_schema = get_phase_2_schema_for_lei(context)
+    logic_checks = [check for col_schema in logic_schema.columns.values() for check in col_schema.checks]
+
+    all_uids = []
+
+    for validation_results, uids in validate_lazy_chunks(
+        syntax_schema, lf, batch_size, batch_count, max_errors, syntax_checks
+    ):
+        all_uids.extend(uids)
+        # validate, and therefore validate_chunks, can return an empty dataframe for findings
+        if not validation_results.findings.is_empty():
+            has_syntax_errors = True
+        yield validation_results
+
+    if not has_syntax_errors:
+        yield validate_register_level(context, all_uids)
+
+    for validation_results, _ in validate_lazy_chunks(
+        logic_schema, lf, batch_size, batch_count, max_errors, logic_checks
+    ):
+        yield validation_results
+
+
 # This function is a Generator, and will yield the results of each batch of processing, along with the
 # phase (SYNTACTICAL/LOGICAL) that the findings were found. Callers of this function will want to
 # store or concat each iteration of findings
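
The same consumption pattern applies to the new validate_lazy_frame generator above. A minimal caller sketch, assuming a CSV-backed LazyFrame; the file name and the context keys are hypothetical, and infer_schema_length=0 simply forces every column to be read as a string:

import polars as pl

from regtech_data_validator.validator import validate_lazy_frame

# Lazily scan the submission; rows are only materialized batch by batch.
lf = pl.scan_csv("submission.csv", infer_schema_length=0)

findings_frames = []
for results in validate_lazy_frame(lf, context={"lei": "EXAMPLELEI000000000A"}):
    # Each yield is one batch's ValidationResults; store or concat the findings.
    if not results.findings.is_empty():
        findings_frames.append(results.findings)

all_findings = pl.concat(findings_frames) if findings_frames else pl.DataFrame()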
@@ -187,15 +224,7 @@ def validate_batch_csv(
         yield validation_results
 
     if not has_syntax_errors:
-        register_schema = get_register_schema(context)
-        validation_results = validate(register_schema, pl.DataFrame({"uid": all_uids}), 0, True)
-        if not validation_results.findings.is_empty():
-            validation_results.findings = format_findings(
-                validation_results.findings,
-                ValidationPhase.LOGICAL.value,
-                [check for col_schema in register_schema.columns.values() for check in col_schema.checks],
-            )
-        yield validation_results
+        yield validate_register_level(context, all_uids)
 
     for validation_results, _ in validate_chunks(
         logic_schema, real_path, batch_size, batch_count, max_errors, logic_checks
@@ -206,6 +235,18 @@ def validate_batch_csv(
     shutil.rmtree("/tmp/s3")
 
 
+def validate_register_level(context: Dict[str, str] | None, all_uids: List[str]):
+    register_schema = get_register_schema(context)
+    validation_results = validate(register_schema, pl.DataFrame({"uid": all_uids}), 0, True)
+    if not validation_results.findings.is_empty():
+        validation_results.findings = format_findings(
+            validation_results.findings,
+            ValidationPhase.LOGICAL.value,
+            [check for col_schema in register_schema.columns.values() for check in col_schema.checks],
+        )
+    return validation_results
+
+
 # Reads in a path to a csv in batches, using batch_size to determine number of rows to read into the buffer,
 # and batch_count to determine how many batches to process in parallel. Performance testing for large files
 # shows 50K batch_size with 1 batch_count to be a nice balance of speed and resource utilization. Increasing
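
As a rough sketch of the reader pattern that comment describes, using the polars batched-reader API behind the BatchedCsvReader import above (the path is hypothetical, and the exhaustion behavior of next_batches is an assumption based on the while-loop below):

import polars as pl

# batch_size bounds the rows buffered per batch; next_batches(n) pulls
# batch_count batches at a time and returns a falsy value once exhausted.
reader = pl.read_csv_batched("submission.csv", batch_size=50000)
batches = reader.next_batches(1)
while batches:
    df = pl.concat(batches)
    # ... validate df and advance the row offset here ...
    batches = reader.next_batches(1)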
@@ -219,22 +260,42 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
     row_start = 0
     while batches:
         df = pl.concat(batches)
-        validation_results = validate(schema, df, row_start, process_errors)
-        if not validation_results.findings.is_empty():
-            validation_results.findings = format_findings(
-                validation_results.findings, validation_results.phase.value, checks
-            )
-
-        total_count += validation_results.findings.height
+        validation_results, total_count, process_errors = validate_chunk(
+            schema, df, total_count, row_start, max_errors, process_errors, checks
+        )
         row_start += df.height
         batches = reader.next_batches(batch_count)
         yield validation_results, df["uid"].to_list()
 
-        if total_count > max_errors and process_errors:
-            process_errors = False
-            head_count = validation_results.findings.height - (total_count - max_errors)
-            validation_results.findings = validation_results.findings.head(head_count)
-
+
+def validate_lazy_chunks(schema, lf: pl.LazyFrame, batch_size: int, batch_count, max_errors, checks):
+    process_errors = True
+    total_count = 0
+    row_start = 0
+    df = lf.slice(row_start, batch_size).collect()
+    while df.height:
+        validation_results, total_count, process_errors = validate_chunk(
+            schema, df, total_count, row_start, max_errors, process_errors, checks
+        )
+        row_start += df.height
+        yield validation_results, df["uid"].to_list()
+        df = lf.slice(row_start, batch_size).collect()
+
+
+def validate_chunk(schema, df, total_count, row_start, max_errors, process_errors, checks):
+    validation_results = validate(schema, df, row_start, process_errors)
+    if not validation_results.findings.is_empty():
+        validation_results.findings = format_findings(
+            validation_results.findings, validation_results.phase.value, checks
+        )
+
+    total_count += validation_results.findings.height
+
+    if total_count > max_errors and process_errors:
+        process_errors = False
+        head_count = validation_results.findings.height - (total_count - max_errors)
+        validation_results.findings = validation_results.findings.head(head_count)
+    return validation_results, total_count, process_errors
 
 
 def get_real_file_path(path):

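The lazy variant swaps the batched CSV reader for slice/collect windows over the LazyFrame. A self-contained sketch of the loop shape used by validate_lazy_chunks, with toy data:

import polars as pl

lf = pl.LazyFrame({"uid": [f"id{i:03d}" for i in range(10)]})

batch_size, row_start = 4, 0
df = lf.slice(row_start, batch_size).collect()  # materialize one window
while df.height:
    print(row_start, df.height)  # -> 0 4, then 4 4, then 8 2
    row_start += df.height
    df = lf.slice(row_start, batch_size).collect()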
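And a worked example of the max_errors cap in validate_chunk: the batch that crosses the limit is trimmed so that exactly max_errors findings survive in total, and process_errors flips to False for the remaining batches. With hypothetical numbers:

max_errors = 1_000_000
total_count = 999_900      # findings accumulated before this batch
batch_findings = 400       # this batch's findings.height

total_count += batch_findings  # 1_000_300, which exceeds the cap
head_count = batch_findings - (total_count - max_errors)
assert head_count == 100       # only the first 100 findings of this batch are kept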