Skip to content

Commit

Permalink
black and ruff fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
hkeeler committed Oct 17, 2023
1 parent b732419 commit 3b18289
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 40 deletions.
20 changes: 5 additions & 15 deletions regtech_data_validator/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from pandera.backends.base import BaseCheckBackend
from pandera.backends.pandas.checks import PandasCheckBackend


class Severity(StrEnum):
ERROR = 'error'
WARNING = 'warning'


class SBLCheck(Check):
"""
A Pandera.Check subclass that requires a `name` and an `id` be
Expand All @@ -23,13 +25,7 @@ class SBLCheck(Check):
SBLWarningCheck subclasses below.
"""

def __init__(self,
check_fn: Callable,
id: str,
name: str,
description: str,
severity: Severity,
**check_kwargs):
def __init__(self, check_fn: Callable, id: str, name: str, description: str, severity: Severity, **check_kwargs):
"""
Subclass of Pandera's `Check`, with special handling for severity level
Args:
Expand All @@ -43,15 +39,9 @@ def __init__(self,

self.severity = severity

super().__init__(
check_fn,
title=id,
name=name,
description=description,
**check_kwargs
)
super().__init__(check_fn, title=id, name=name, description=description, **check_kwargs)

@classmethod
def get_backend(cls, check_obj: Any) -> Type[BaseCheckBackend]:
"""Assume Pandas DataFrame and return PandasCheckBackend"""
return PandasCheckBackend
return PandasCheckBackend
15 changes: 7 additions & 8 deletions regtech_data_validator/create_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
phase_2_template = get_template()


def get_schema_by_phase_for_lei(template: dict, phase: str, lei: str|None = None):
def get_schema_by_phase_for_lei(template: dict, phase: str, lei: str | None = None):
for column in get_phase_1_and_2_validations_for_lei(lei):
validations = get_phase_1_and_2_validations_for_lei(lei)[column]
template[column].checks = validations[phase]
return DataFrameSchema(template)


def get_phase_1_schema_for_lei(lei: str|None = None):
def get_phase_1_schema_for_lei(lei: str | None = None):
return get_schema_by_phase_for_lei(phase_1_template, "phase_1", lei)


def get_phase_2_schema_for_lei(lei: str|None = None):
def get_phase_2_schema_for_lei(lei: str | None = None):
return get_schema_by_phase_for_lei(phase_2_template, "phase_2", lei)


Expand All @@ -46,11 +46,10 @@ def validate(schema: DataFrameSchema, df: pd.DataFrame) -> list[dict]:
try:
schema(df, lazy=True)
except SchemaErrors as err:

# WARN: SchemaErrors.schema_errors is supposed to be of type
# list[dict[str,Any]], but it's actually a list of SchemaError objects
schema_error: SchemaError
for schema_error in err.schema_errors: # type: ignore
for schema_error in err.schema_errors: # type: ignore
check = schema_error.check
column_name = schema_error.schema.name

Expand All @@ -63,7 +62,7 @@ def validate(schema: DataFrameSchema, df: pd.DataFrame) -> list[dict]:
raise RuntimeError(
f'Check {check} type on {column_name} column not supported. Must be of type {SBLCheck}'
) from schema_error

fields: list[str] = [column_name]

if check.groupby:
Expand Down Expand Up @@ -110,7 +109,7 @@ def validate(schema: DataFrameSchema, df: pd.DataFrame) -> list[dict]:
return findings


def validate_phases(df: pd.DataFrame, lei: str|None = None) -> list:
def validate_phases(df: pd.DataFrame, lei: str | None = None) -> list:
phase1_findings = validate(get_phase_1_schema_for_lei(lei), df)
if phase1_findings:
return phase1_findings
Expand All @@ -119,4 +118,4 @@ def validate_phases(df: pd.DataFrame, lei: str|None = None) -> list:
if phase2_findings:
return phase2_findings
else:
return [{"response": "No validations errors or warnings"}]
return [{"response": "No validations errors or warnings"}]
2 changes: 1 addition & 1 deletion regtech_data_validator/global_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


# global variable for NAICS codes
naics_codes: dict[str,str] = {}
naics_codes: dict[str, str] = {}
naics_file_path = files('regtech_data_validator.data.naics').joinpath('2022_codes.csv')

with naics_file_path.open('r') as f:
Expand Down
7 changes: 4 additions & 3 deletions regtech_data_validator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def csv_to_df(path: str) -> pd.DataFrame:
return pd.read_csv(path, dtype=str, na_filter=False)


def run_validation_on_df(df: pd.DataFrame, lei: str|None) -> None:
def run_validation_on_df(df: pd.DataFrame, lei: str | None) -> None:
"""
Run validation on the supplied dataframe and print a report to
the terminal.
Expand All @@ -31,7 +31,7 @@ def run_validation_on_df(df: pd.DataFrame, lei: str|None) -> None:

def main():
csv_path = None
lei: str|None = None
lei: str | None = None
if len(sys.argv) == 1:
raise ValueError("csv_path arg not provided")
elif len(sys.argv) == 2:
Expand All @@ -45,5 +45,6 @@ def main():
df = csv_to_df(csv_path)
run_validation_on_df(df, lei)


if __name__ == "__main__":
main()
main()
9 changes: 4 additions & 5 deletions regtech_data_validator/phase_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from regtech_data_validator.checks import SBLCheck, Severity


def get_phase_1_and_2_validations_for_lei(lei: str|None = None):
def get_phase_1_and_2_validations_for_lei(lei: str | None = None):
return {
"uid": {
"phase_1": [
Expand Down Expand Up @@ -1446,8 +1446,7 @@ def get_phase_1_and_2_validations_for_lei(lei: str|None = None):
id="E0720",
name="naics_code_flag.invalid_enum_value",
description=(
"'North American Industry Classification System (NAICS) code: NP flag'"
"must equal 900 or 988."
"'North American Industry Classification System (NAICS) code: NP flag'must equal 900 or 988."
),
severity=Severity.ERROR,
element_wise=True,
Expand Down Expand Up @@ -3305,7 +3304,7 @@ def get_phase_1_and_2_validations_for_lei(lei: str|None = None):
" field for other Pacific Islander race' must"
" not exceed 300 characters in length."
),
severity=Severity.ERROR
severity=Severity.ERROR,
),
],
"phase_2": [
Expand Down Expand Up @@ -3383,4 +3382,4 @@ def get_phase_1_and_2_validations_for_lei(lei: str|None = None):
),
],
},
}
}
7 changes: 2 additions & 5 deletions regtech_data_validator/schema_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,7 @@
),
"pricing_mca_addcost": Column(
str,
title=(
"Field 31: MCA/sales-based: additional cost for merchant cash "
"advances or other sales-based financing"
),
title="Field 31: MCA/sales-based: additional cost for merchant cash advances or other sales-based financing",
checks=[],
),
"pricing_prepenalty_allowed": Column(
Expand Down Expand Up @@ -449,4 +446,4 @@ def get_template() -> Dict:
cause absolute havoc in a program and it's practically impossible
to debug."""

return deepcopy(_schema_template)
return deepcopy(_schema_template)
2 changes: 0 additions & 2 deletions tests/test_global_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from regtech_data_validator import global_data


Expand Down
3 changes: 2 additions & 1 deletion tests/test_schema_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from regtech_data_validator.create_schemas import (
get_phase_1_schema_for_lei,
get_phase_2_schema_for_lei,
validate, validate_phases
validate,
validate_phases,
)


Expand Down

0 comments on commit 3b18289

Please sign in to comment.