From de26c42174fcab9651e1c7bd15541243459d0ac7 Mon Sep 17 00:00:00 2001 From: lchen <73617864+lchen-2101@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:36:42 -0800 Subject: [PATCH] fix: correctly offset the index for batched validation --- src/regtech_data_validator/validator.py | 7 +++--- tests/test_sample_data.py | 29 ++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/regtech_data_validator/validator.py b/src/regtech_data_validator/validator.py index 917f62a..a32c309 100644 --- a/src/regtech_data_validator/validator.py +++ b/src/regtech_data_validator/validator.py @@ -121,6 +121,7 @@ def validate( if check_output is not None: # Filter data not associated with failed Check, and update index for merging with findings_df + check_output = check_output.with_columns(pl.col('index').add(row_start)) failed_records_df = _filter_valid_records(submission_df, check_output, fields) failed_record_fields_df = _records_to_fields(failed_records_df) findings = _add_validation_metadata(failed_record_fields_df, check) @@ -133,16 +134,16 @@ def validate( if check_findings: findings_df = pl.concat(check_findings) - updated_df = add_uid(findings_df, submission_df) + updated_df = add_uid(findings_df, submission_df, row_start) return updated_df # Add the uid for the record throwing the error/warning to the error dataframe -def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame) -> pl.DataFrame: +def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame, offset: int) -> pl.DataFrame: if results_df.is_empty(): return results_df - uid_records = results_df['record_no'] - 1 + uid_records = results_df['record_no'] - 1 - offset results_df = results_df.with_columns(submission_df['uid'].gather(uid_records).alias('uid')) return results_df diff --git a/tests/test_sample_data.py b/tests/test_sample_data.py index 3033a50..cac9cf6 100644 --- a/tests/test_sample_data.py +++ b/tests/test_sample_data.py @@ -44,7 +44,8 @@ def test_all_logic_errors(self): vresults = [] for vresult in validate_batch_csv(ALL_LOGIC_ERRORS): vresults.append(vresult) - + # 3 phases + assert len(vresults) == 3 results = pl.concat([vr.findings for vr in vresults], how="diagonal") logic_schema = get_phase_2_schema_for_lei() @@ -85,3 +86,29 @@ def test_all_logic_warnings(self): # check that the findings validation_id Series contains at least 1 of every logic warning check id assert len(set(results['validation_id'].to_list()).difference(set(logic_checks))) == 0 assert results.select(pl.col('phase').eq(ValidationPhase.LOGICAL.value).all()).item() + + def test_all_logic_errors_batched(self): + vresults = [] + for vresult in validate_batch_csv(ALL_LOGIC_ERRORS, batch_size=3): + vresults.append(vresult) + # 3 phases with 3 batches + assert len(vresults) == 9 + results = pl.concat([vr.findings for vr in vresults], how="diagonal") + + logic_schema = get_phase_2_schema_for_lei() + register_schema = get_register_schema() + logic_checks = [ + check.title + for col_schema in logic_schema.columns.values() + for check in col_schema.checks + if check.severity == Severity.ERROR + ] + logic_checks.extend( + [check.title for col_schema in register_schema.columns.values() for check in col_schema.checks] + ) + + results = results.filter(pl.col('validation_type') == 'Error') + + # check that the findings validation_id Series contains at least 1 of every logic error check id + assert len(set(results['validation_id'].to_list()).difference(set(logic_checks))) == 0 + assert results.select(pl.col('phase').eq(ValidationPhase.LOGICAL.value).all()).item() \ No newline at end of file