Skip to content

Commit

Permalink
fix: correctly offset the index for batched validation
Browse files Browse the repository at this point in the history
  • Loading branch information
lchen-2101 committed Nov 14, 2024
1 parent aa141a6 commit de26c42
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
7 changes: 4 additions & 3 deletions src/regtech_data_validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def validate(

if check_output is not None:
# Filter data not associated with failed Check, and update index for merging with findings_df
check_output = check_output.with_columns(pl.col('index').add(row_start))
failed_records_df = _filter_valid_records(submission_df, check_output, fields)
failed_record_fields_df = _records_to_fields(failed_records_df)
findings = _add_validation_metadata(failed_record_fields_df, check)
Expand All @@ -133,16 +134,16 @@ def validate(
if check_findings:
findings_df = pl.concat(check_findings)

updated_df = add_uid(findings_df, submission_df)
updated_df = add_uid(findings_df, submission_df, row_start)
return updated_df


def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame, offset: int = 0) -> pl.DataFrame:
    """Add the uid of the record that raised each error/warning to the findings dataframe.

    Args:
        results_df: Findings dataframe; its ``record_no`` column is read.
        submission_df: Dataframe that was validated; its ``uid`` column is read.
        offset: Value subtracted (together with 1) from ``record_no`` to obtain
            positions into ``submission_df``. The caller passes the starting row
            of the current batch. Defaults to 0, so two-argument calls keep the
            un-offset behaviour.

    Returns:
        ``results_df`` unchanged when it is empty; otherwise ``results_df`` with
        a ``uid`` column gathered from ``submission_df``.
    """
    if results_df.is_empty():
        return results_df

    # Translate record numbers into positions within submission_df: drop the
    # 1-based adjustment, then remove the batch offset.
    uid_records = results_df['record_no'] - 1 - offset
    results_df = results_df.with_columns(submission_df['uid'].gather(uid_records).alias('uid'))
    return results_df

Expand Down
29 changes: 28 additions & 1 deletion tests/test_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def test_all_logic_errors(self):
vresults = []
for vresult in validate_batch_csv(ALL_LOGIC_ERRORS):
vresults.append(vresult)

# 3 phases
assert len(vresults) == 3
results = pl.concat([vr.findings for vr in vresults], how="diagonal")

logic_schema = get_phase_2_schema_for_lei()
Expand Down Expand Up @@ -85,3 +86,29 @@ def test_all_logic_warnings(self):
# check that the findings validation_id Series contains at least 1 of every logic warning check id
assert len(set(results['validation_id'].to_list()).difference(set(logic_checks))) == 0
assert results.select(pl.col('phase').eq(ValidationPhase.LOGICAL.value).all()).item()

def test_all_logic_errors_batched(self):
    """Run batched validation (batch_size=3) over ALL_LOGIC_ERRORS and check the error findings."""
    batch_results = list(validate_batch_csv(ALL_LOGIC_ERRORS, batch_size=3))
    # 3 phases with 3 batches
    assert len(batch_results) == 9

    findings = pl.concat([batch.findings for batch in batch_results], how="diagonal")
    error_findings = findings.filter(pl.col('validation_type') == 'Error')

    # Known check ids: error-severity checks from the phase 2 schema plus
    # every check from the register schema.
    error_check_ids = {
        check.title
        for col_schema in get_phase_2_schema_for_lei().columns.values()
        for check in col_schema.checks
        if check.severity == Severity.ERROR
    }
    register_check_ids = {
        check.title
        for col_schema in get_register_schema().columns.values()
        for check in col_schema.checks
    }
    known_ids = error_check_ids | register_check_ids

    # Every validation_id reported as an error must be one of the known check ids.
    reported_ids = set(error_findings['validation_id'].to_list())
    assert reported_ids.issubset(known_ids)

    # All error findings must come from the logical phase.
    is_logical_phase = pl.col('phase').eq(ValidationPhase.LOGICAL.value)
    assert error_findings.select(is_logical_phase.all()).item()

0 comments on commit de26c42

Please sign in to comment.