diff --git a/src/regtech_data_validator/validator.py b/src/regtech_data_validator/validator.py index 917f62a..a32c309 100644 --- a/src/regtech_data_validator/validator.py +++ b/src/regtech_data_validator/validator.py @@ -121,6 +121,7 @@ def validate( if check_output is not None: # Filter data not associated with failed Check, and update index for merging with findings_df + check_output = check_output.with_columns(pl.col('index').add(row_start)) failed_records_df = _filter_valid_records(submission_df, check_output, fields) failed_record_fields_df = _records_to_fields(failed_records_df) findings = _add_validation_metadata(failed_record_fields_df, check) @@ -133,16 +134,16 @@ def validate( if check_findings: findings_df = pl.concat(check_findings) - updated_df = add_uid(findings_df, submission_df) + updated_df = add_uid(findings_df, submission_df, row_start) return updated_df # Add the uid for the record throwing the error/warning to the error dataframe -def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame) -> pl.DataFrame: +def add_uid(results_df: pl.DataFrame, submission_df: pl.DataFrame, offset: int) -> pl.DataFrame: if results_df.is_empty(): return results_df - uid_records = results_df['record_no'] - 1 + uid_records = results_df['record_no'] - 1 - offset results_df = results_df.with_columns(submission_df['uid'].gather(uid_records).alias('uid')) return results_df diff --git a/tests/test_sample_data.py b/tests/test_sample_data.py index 3033a50..cac9cf6 100644 --- a/tests/test_sample_data.py +++ b/tests/test_sample_data.py @@ -44,7 +44,8 @@ def test_all_logic_errors(self): vresults = [] for vresult in validate_batch_csv(ALL_LOGIC_ERRORS): vresults.append(vresult) - + # 3 phases + assert len(vresults) == 3 results = pl.concat([vr.findings for vr in vresults], how="diagonal") logic_schema = get_phase_2_schema_for_lei() @@ -85,3 +86,29 @@ def test_all_logic_warnings(self): # check that the findings validation_id Series contains at least 1 of every logic 
warning check id
         assert len(set(results['validation_id'].to_list()).difference(set(logic_checks))) == 0
         assert results.select(pl.col('phase').eq(ValidationPhase.LOGICAL.value).all()).item()
+
+    def test_all_logic_errors_batched(self):
+        vresults = []
+        for vresult in validate_batch_csv(ALL_LOGIC_ERRORS, batch_size=3):
+            vresults.append(vresult)
+        # 3 phases with 3 batches
+        assert len(vresults) == 9
+        results = pl.concat([vr.findings for vr in vresults], how="diagonal")
+
+        logic_schema = get_phase_2_schema_for_lei()
+        register_schema = get_register_schema()
+        logic_checks = [
+            check.title
+            for col_schema in logic_schema.columns.values()
+            for check in col_schema.checks
+            if check.severity == Severity.ERROR
+        ]
+        logic_checks.extend(
+            [check.title for col_schema in register_schema.columns.values() for check in col_schema.checks]
+        )
+
+        results = results.filter(pl.col('validation_type') == 'Error')
+
+        # check that the findings validation_id Series contains at least 1 of every logic error check id
+        assert len(set(results['validation_id'].to_list()).difference(set(logic_checks))) == 0
+        assert results.select(pl.col('phase').eq(ValidationPhase.LOGICAL.value).all()).item()