Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/ruff-0.7.3
Browse files Browse the repository at this point in the history
  • Loading branch information
jcadam14 authored Nov 13, 2024
2 parents c225a78 + aa141a6 commit 7a1fbd0
Showing 1 changed file with 58 additions and 85 deletions.
143 changes: 58 additions & 85 deletions src/regtech_data_validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def validate(
pd.DataFrame containing validation results data
"""
findings_df: pl.DataFrame = pl.DataFrame()
error_counts = warning_counts = Counts()

try:
# since polars dataframes don't normally have an index column, add it, so that we can match
Expand All @@ -94,7 +93,6 @@ def validate(
# `list[dict[str,Any]]`, but it's actually of type `SchemaError`
schema_error: SchemaError

error_counts, warning_counts = get_scope_counts(err.schema_errors)
if process_errors:
for schema_error in err.schema_errors:
check = schema_error.check
Expand Down Expand Up @@ -136,14 +134,7 @@ def validate(
findings_df = pl.concat(check_findings)

updated_df = add_uid(findings_df, submission_df)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=updated_df,
phase=schema.name,
)
return results
return updated_df


# Add the uid for the record throwing the error/warning to the error dataframe
Expand Down Expand Up @@ -189,13 +180,21 @@ def validate_batch_csv(
if not has_syntax_errors:
register_schema = get_register_schema(context)
validation_results = validate(register_schema, pl.DataFrame({"uid": all_uids}), 0, True)
if not validation_results.findings.is_empty():
validation_results.findings = format_findings(
validation_results.findings,
if not validation_results.is_empty():
validation_results = format_findings(
validation_results,
ValidationPhase.LOGICAL.value,
[check for col_schema in register_schema.columns.values() for check in col_schema.checks],
)
yield validation_results
error_counts, warning_counts = get_scope_counts(validation_results)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=validation_results,
phase=register_schema.name,
)
yield results

for validation_results, _ in validate_chunks(
logic_schema, real_path, batch_size, batch_count, max_errors, logic_checks
Expand All @@ -220,21 +219,29 @@ def validate_chunks(schema, path, batch_size, batch_count, max_errors, checks):
while batches:
df = pl.concat(batches)
validation_results = validate(schema, df, row_start, process_errors)
if not validation_results.findings.is_empty():
validation_results.findings = format_findings(
validation_results.findings, validation_results.phase.value, checks
)
if not validation_results.is_empty():

total_count += validation_results.findings.height
validation_results = format_findings(validation_results, schema.name.value, checks)

error_counts, warning_counts = get_scope_counts(validation_results)
results = ValidationResults(
error_counts=error_counts,
warning_counts=warning_counts,
is_valid=((error_counts.total_count + warning_counts.total_count) == 0),
findings=validation_results,
phase=schema.name,
)

total_count += results.findings.height

if total_count > max_errors and process_errors:
process_errors = False
head_count = validation_results.findings.height - (total_count - max_errors)
validation_results.findings = validation_results.findings.head(head_count)
head_count = results.findings.height - (total_count - max_errors)
results.findings = results.findings.head(head_count)

row_start += df.height
batches = reader.next_batches(batch_count)
yield validation_results, df["uid"].to_list()
yield results, df["uid"].to_list()


def get_real_file_path(path):
Expand All @@ -256,68 +263,34 @@ def gather_errors(schema_error: SchemaError):
return schema_error


def get_scope_counts(schema_errors: list[SchemaError]):
singles = [
error for error in schema_errors if isinstance(error.check, SBLCheck) and error.check.scope == 'single-field'
]

single_errors = int(
sum(
[
(error.check_output.filter(~pl.col("check_output"))).height
for error in singles
if error.check.severity == Severity.ERROR
]
def get_scope_counts(error_frame: pl.DataFrame):
if not error_frame.is_empty():
single_errors = error_frame.filter(
(pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "single-field")
).height
single_warnings = error_frame.filter(
(pl.col("validation_type") == Severity.WARNING) & (pl.col("scope") == "single-field")
).height
register_errors = error_frame.filter(
(pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "register")
).height
multi_errors = error_frame.filter(
(pl.col("validation_type") == Severity.ERROR) & (pl.col("scope") == "multi-field")
).height
multi_warnings = error_frame.filter(
(pl.col("validation_type") == Severity.WARNING) & (pl.col("scope") == "multi-field")
).height

return Counts(
single_field_count=single_errors,
multi_field_count=multi_errors,
register_count=register_errors,
total_count=sum([single_errors, multi_errors, register_errors]),
), Counts(
single_field_count=single_warnings,
multi_field_count=multi_warnings,
total_count=sum([single_warnings, multi_warnings]), # There are no register-level warnings at this time
)
)
single_warnings = int(
sum(
[
(error.check_output.filter(~pl.col("check_output"))).height
for error in singles
if error.check.severity == Severity.WARNING
]
)
)
multi = [
error for error in schema_errors if isinstance(error.check, SBLCheck) and error.check.scope == 'multi-field'
]
multi_errors = int(
sum(
[
(error.check_output.filter(~pl.col("check_output"))).height
for error in multi
if error.check.severity == Severity.ERROR
]
)
)
multi_warnings = int(
sum(
[
(error.check_output.filter(~pl.col("check_output"))).height
for error in multi
if error.check.severity == Severity.WARNING
]
)
)

register_errors = int(
sum(
[
(error.check_output.filter(~pl.col("check_output"))).height
for error in schema_errors
if isinstance(error.check, SBLCheck) and error.check.scope == 'register'
]
)
)

return Counts(
single_field_count=single_errors,
multi_field_count=multi_errors,
register_count=register_errors,
total_count=sum([single_errors, multi_errors, register_errors]),
), Counts(
single_field_count=single_warnings,
multi_field_count=multi_warnings,
total_count=sum([single_warnings, multi_warnings]), # There are no register-level warnings at this time
)
else:
return Counts(), Counts()

0 comments on commit 7a1fbd0

Please sign in to comment.