Adding FailedRowProcessor support in soda-spark #114

Open
wants to merge 7 commits into base: main
Changes from 4 commits
16 changes: 15 additions & 1 deletion src/sodaspark/scan.py
@@ -9,6 +9,7 @@
from pyspark.sql import types as T # noqa: N812
from sodasql.common.yaml_helper import YamlHelper
from sodasql.dialects.spark_dialect import SparkDialect
from sodasql.scan.failed_rows_processor import FailedRowsProcessor
from sodasql.scan.file_system import FileSystemSingleton
from sodasql.scan.measurement import Measurement
from sodasql.scan.scan import Scan
@@ -255,6 +256,7 @@ def create_scan(
warehouse_name: str = "sodaspark",
soda_server_client: SodaServerClient | None = None,
time: str | None = None,
failed_rows_processor: FailedRowsProcessor | None = None,
) -> Scan:
"""
Create a scan object.
@@ -263,11 +265,16 @@
----------
scan_yml : ScanYml
The scan yml.
variables: variables to be substituted in scan yml
variables: Optional[dict] (default: None)
variables to be substituted in scan yml
warehouse_name: Optional[str] (default: sodaspark)
The name of the warehouse
soda_server_client : Optional[SodaServerClient] (default : None)
A soda server client.
time: Optional[str] (default: None)
Timestamp date in ISO8601 format. If None, use datetime.now() in ISO8601 format.
failed_rows_processor: Optional[FailedRowsProcessor] (default: None)
A FailedRowsProcessor implementation

Returns
-------
@@ -285,6 +292,7 @@
soda_server_client=soda_server_client,
variables=variables,
time=time,
failed_rows_processor=failed_rows_processor,
)
return scan

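For reference, the processor contract exercised here is a single `process(context)` method returning a dict, as the new test below demonstrates. A minimal sketch of a custom implementation — the class name and message are chosen purely for illustration — could look like this:

```python
from sodasql.scan.failed_rows_processor import FailedRowsProcessor


class StdoutFailedRowsProcessor(FailedRowsProcessor):
    """Illustrative processor that writes each failed-rows sample to stdout."""

    def process(self, context: dict) -> dict:
        # The context dict carries keys such as sample_name, column_name,
        # sample_rows and total_row_count (see the expected output in the test below).
        print(context)
        return {"message": "Failed rows were printed to stdout"}
```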
@@ -430,6 +438,7 @@ def execute(
soda_server_client: SodaServerClient | None = None,
as_frames: bool | None = False,
time: str | None = None,
failed_rows_processor: FailedRowsProcessor | None = None,
) -> ScanResult:
"""
Execute a scan on a data frame.
@@ -442,12 +451,16 @@
The data frame to be scanned.
variables: Optional[dict] (default : None)
Variables to be substituted in scan yml
warehouse_name: Optional[str] (default: sodaspark)
The name of the warehouse
soda_server_client : Optional[SodaServerClient] (default : None)
A soda server client.
as_frames : bool (default: False)
Flag to return the results as data frames
time: str (default : None)
Timestamp date in ISO8601 format at the start of a scan
failed_rows_processor: Optional[FailedRowsProcessor] (default: None)
A FailedRowsProcessor implementation

Returns
-------
@@ -463,6 +476,7 @@
soda_server_client=soda_server_client,
time=time,
warehouse_name=warehouse_name,
failed_rows_processor=failed_rows_processor,
)
scan.execute()

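With the parameter threaded through, a caller can hand a processor straight to `execute`. A rough usage sketch follows, assuming the module is imported as `from sodaspark import scan` and using a trivial processor in the spirit of the one added to the tests; the DataFrame and checks are placeholders:

```python
from pyspark.sql import SparkSession
from sodasql.scan.failed_rows_processor import FailedRowsProcessor

from sodaspark import scan  # assumed import path for this module


class PrintFailedRowsProcessor(FailedRowsProcessor):
    def process(self, context: dict) -> dict:
        print(context)  # ship the failed-rows sample wherever you need it
        return {"message": "Failed rows were printed to stdout"}


spark_session = SparkSession.builder.getOrCreate()
df = spark_session.createDataFrame(
    [("1", 100), ("2", None)], ["id", "number"]
)

scan_definition = """
table_name: my_table
metric_groups:
  - all
samples:
  table_limit: 5
  failed_limit: 5
columns:
  number:
    tests:
      - missing_count == 0
"""

scan_result = scan.execute(
    scan_definition=scan_definition,
    df=df,
    failed_rows_processor=PrintFailedRowsProcessor(),
)
```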
75 changes: 75 additions & 0 deletions tests/test_scan.py
@@ -6,10 +6,13 @@
from typing import BinaryIO

import pytest
from _pytest.capture import CaptureFixture
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import types as T # noqa: N812
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from sodasql.dialects.spark_dialect import SparkDialect
from sodasql.scan.failed_rows_processor import FailedRowsProcessor
from sodasql.scan.group_value import GroupValue
from sodasql.scan.measurement import Measurement
from sodasql.scan.scan_error import TestExecutionScanError
@@ -183,6 +186,22 @@ def df(spark_session: SparkSession) -> DataFrame:
return df


class InMemoryFailedRowProcessor(FailedRowsProcessor):
    def process(self, context: dict) -> dict:

        try:
            print(context)
        except Exception:
            raise Exception

        return {"message": "All failed rows were printed in your terminal"}

Contributor
This try except does not do anything, right?

Author
Yes, you are correct. I was just following the pattern I found in this doc 😅 Just changed the except to throw the exception 🤔

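As noted in the review thread above, the try/except around the print is effectively a no-op; an equivalent processor without it (same behaviour, just simplified) would be:

```python
class InMemoryFailedRowProcessor(FailedRowsProcessor):
    def process(self, context: dict) -> dict:
        # Print the sample so the test can capture it from stdout via capsys.
        print(context)
        return {"message": "All failed rows were printed in your terminal"}
```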

@pytest.fixture
def failed_rows_processor() -> FailedRowsProcessor:
return InMemoryFailedRowProcessor()


def test_create_scan_yml_table_name_is_demodata(
scan_definition: str,
) -> None:
@@ -507,3 +526,59 @@ def test_scan_execute_return_as_data_frame(
(scan_result[1].count(), len(scan_result[1].columns)),
(scan_result[2].count(), len(scan_result[2].columns)),
)


def test_failed_row_processor_return_correct_values(
spark_session: SparkSession,
failed_rows_processor: FailedRowsProcessor,
capsys: CaptureFixture,
) -> None:

expected_output = [
"{'sample_name': 'dataset', 'column_name': None, 'test_ids': None, "
"'sample_columns': [{'name': 'id', 'type': 'string'}, {'name': 'number', "
"'type': 'int'}], 'sample_rows': [['1', 100], ['2', 200], ['3', None], ['4', "
"400]], 'sample_description': 'my_table.sample', 'total_row_count': 4}",
"{'sample_name': 'missing', 'column_name': 'number', 'test_ids': "
'[\'{"column":"number","expression":"missing_count == 0"}\'], '
"'sample_columns': [{'name': 'id', 'type': 'string'}, {'name': 'number', "
"'type': 'int'}], 'sample_rows': [['3', None]], 'sample_description': "
"'my_table.number.missing', 'total_row_count': 1}",
"",
]

data = [("1", 100), ("2", 200), ("3", None), ("4", 400)]

schema = StructType(
[
StructField("id", StringType(), True),
StructField("number", IntegerType(), True),
]
)

df = spark_session.createDataFrame(data=data, schema=schema)

scan_definition = """
table_name: my_table
metric_groups:
- all
samples:
table_limit: 5
failed_limit: 5
tests:
- row_count > 0
columns:
number:
tests:
- duplicate_count == 0
- missing_count == 0
"""

scan.execute(
scan_definition=scan_definition,
df=df,
failed_rows_processor=failed_rows_processor,
)

out, err = capsys.readouterr()
assert expected_output == out.split("\n")