[Feature] updating code to work for retries if the save table fails (#61)

* updating code to work for retries if the save table fails; ALTER TABLE is not done for the stats table

* Removing retries as they're not needed
asingamaneni authored Dec 7, 2023
1 parent 5e6c601 commit 1c87e21
Showing 3 changed files with 62 additions and 23 deletions.
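For orientation, a minimal sketch (not part of this commit) of the mechanism the change leans on: SHOW TBLPROPERTIES returns one row per property, with key and value columns, and the writer collects those rows into a plain dict before checking product_id. The session setup and demo_table name below are assumptions for illustration:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
# Placeholder table; any catalog table works here.
spark.sql("CREATE TABLE IF NOT EXISTS demo_table (id INT) USING parquet")

# Each result row looks like Row(key='product_id', value='product1');
# a dict comprehension turns the rows into a straightforward lookup.
rows = spark.sql("SHOW TBLPROPERTIES demo_table").collect()
props = {row["key"]: row["value"] for row in rows}
print(props.get("product_id"))  # None until the property has been set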
6 changes: 3 additions & 3 deletions spark_expectations/core/exceptions.py
@@ -43,23 +43,23 @@ class SparkExpectationsMiscException(Exception):

class SparkExpectationsSlackNotificationException(Exception):
    """
-   Throw this exception when spark expectations encounters miscellaneous exceptions
+   Throw this exception when spark expectations encounters exceptions while sending Slack notifications
    """

    pass


class SparkExpectationsTeamsNotificationException(Exception):
    """
-   Throw this exception when spark expectations encounters miscellaneous exceptions
+   Throw this exception when spark expectations encounters exceptions while sending Teams notifications
    """

    pass


class SparkExpectationsEmailException(Exception):
    """
-   Throw this exception when spark expectations encounters miscellaneous exceptions
+   Throw this exception when spark expectations encounters exceptions while sending email notifications
    """

    pass
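A hedged usage sketch, not taken from this commit, of how the reworded exceptions are meant to be raised: wrap a channel-specific failure and re-raise. The _post_to_slack helper is hypothetical; only the import path comes from the file above:

from spark_expectations.core.exceptions import (
    SparkExpectationsSlackNotificationException,
)


def _post_to_slack(message: str) -> None:
    # Hypothetical transport call; raises to simulate a webhook failure.
    raise ConnectionError("webhook unreachable")


def notify_slack(message: str) -> None:
    try:
        _post_to_slack(message)
    except Exception as e:
        # Per the updated docstring: raised for failures while sending
        # Slack notifications, not for miscellaneous errors.
        raise SparkExpectationsSlackNotificationException(
            f"error occurred while sending Slack notification: {e}"
        ) from e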
30 changes: 26 additions & 4 deletions spark_expectations/sinks/utils/writer.py
@@ -83,14 +83,36 @@ def save_df_as_table(
if config["options"] is not None and config["options"] != {}:
_df_writer = _df_writer.options(**config["options"])

_log.info("Writing records to table: %s", table_name)

if config["format"] == "bigquery":
_df_writer.option("table", table_name).save()
else:
_df_writer.saveAsTable(name=table_name)
self.spark.sql(
f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{self._context.product_id}')"
)
_log.info("finished writing records to table: %s,", table_name)
_log.info("finished writing records to table: %s,", table_name)
if not stats_table:
# Fetch table properties
table_properties = self.spark.sql(
f"SHOW TBLPROPERTIES {table_name}"
).collect()
table_properties_dict = {
row["key"]: row["value"] for row in table_properties
}

# Set product_id in table properties
if (
table_properties_dict.get("product_id") is None
or table_properties_dict.get("product_id")
!= self._context.product_id
):
_log.info(
"product_id is not set for table %s in tableproperties, setting it now",
table_name,
)
self.spark.sql(
f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = "
f"'{self._context.product_id}')"
)

except Exception as e:
raise SparkExpectationsUserInputOrConfigInvalidException(
Expand Down
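A self-contained sketch of the retry behaviour this hunk targets: running the same save twice must not re-issue the DDL. The helper below mirrors the new guard rather than calling SparkExpectationsWriter itself; the table name and product id are placeholders. Note that the guard lives in the non-bigquery branch above, so BigQuery writes never touch TBLPROPERTIES:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.sql("CREATE TABLE IF NOT EXISTS demo_table (id INT) USING parquet")


def set_product_id_once(table: str, product_id: str) -> None:
    # Mirror of the guard in save_df_as_table: read the current properties,
    # then issue ALTER TABLE only when product_id is absent or stale.
    props = {
        row["key"]: row["value"]
        for row in spark.sql(f"SHOW TBLPROPERTIES {table}").collect()
    }
    if props.get("product_id") != product_id:
        spark.sql(
            f"ALTER TABLE {table} SET TBLPROPERTIES ('product_id' = '{product_id}')"
        )


set_product_id_once("demo_table", "product1")  # first run: ALTER TABLE fires
set_product_id_once("demo_table", "product1")  # retry: property matches, no DDL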
49 changes: 33 additions & 16 deletions tests/sinks/utils/test_writer.py
@@ -2,7 +2,6 @@
import unittest.mock
from unittest.mock import patch, Mock

-import pyspark.sql
import pytest
from pyspark.sql.functions import col
from pyspark.sql.functions import lit, to_timestamp
@@ -57,7 +56,7 @@ def fixture_writer():
setattr(mock_context, "get_run_id_name", "meta_dq_run_id")
setattr(mock_context, "get_run_date_name", "meta_dq_run_date")
mock_context.spark = spark
mock_context.product_id='product1'
mock_context.product_id = 'product1'

# Create an instance of the class and set the product_id
return SparkExpectationsWriter(mock_context)
@@ -160,9 +159,11 @@ def fixture_expected_dq_dataset():
@pytest.mark.parametrize('table_name, options, expected_count',
                         [('employee_table',
                           {'mode': 'overwrite', 'partitionBy': ['department'], "format": "parquet",
-                           'bucketBy': {'numBuckets':2,'colName':'business_unit'}, 'sortBy': ["eeid"], 'options': {"overwriteSchema": "true", "mergeSchema": "true"}}, 1000),
+                           'bucketBy': {'numBuckets': 2, 'colName': 'business_unit'}, 'sortBy': ["eeid"],
+                           'options': {"overwriteSchema": "true", "mergeSchema": "true"}}, 1000),
                          ('employee_table',
-                          {'mode': 'append', "format": "delta", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}},
+                          {'mode': 'append', "format": "delta", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [],
+                           'options': {"mergeSchema": "true"}},
                           1000)
                         ])
def test_save_df_as_table(table_name,
@@ -175,17 +176,31 @@ def test_save_df_as_table(table_name,

    assert expected_count == spark.sql(f"select * from {table_name}").count()

-   # Assert
-   # _spark_set.assert_called_with('spark.sql.session.timeZone', 'Etc/UTC')
+   # Fetch table properties
+   table_properties = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
+   table_properties_dict = {row["key"]: row["value"] for row in table_properties}
+
+   # Check 'product_id' property
+   assert table_properties_dict.get("product_id") == _fixture_writer._context.product_id

    spark.sql(f"drop table if exists {table_name}")
    spark.sql(f"drop table if exists {table_name}_stats")
    spark.sql(f"drop table if exists {table_name}_error")

+   _fixture_writer.save_df_as_table(_fixture_employee, table_name, options, True)
+   # Fetch table properties
+   table_properties = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
+   table_properties_dict = {row["key"]: row["value"] for row in table_properties}
+   assert table_properties_dict.get("product_id") is None

@patch('pyspark.sql.DataFrameWriter.save', autospec=True, spec_set=True)
def test_save_df_to_table_bq(save, _fixture_writer, _fixture_employee, _fixture_create_employee_table):
    _fixture_writer.save_df_as_table(_fixture_employee, 'employee_table', {'mode': 'overwrite', 'format': 'bigquery',
-                                                                          'partitionBy':[], 'bucketBy': {}, 'sortBy':[], 'options':{}})
+                                                                          'partitionBy': [], 'bucketBy': {},
+                                                                          'sortBy': [], 'options': {}})
    save.assert_called_once_with(unittest.mock.ANY)


@pytest.mark.parametrize('table_name, options',
[('employee_table', {'mode': 'overwrite',
'partitionBy': ['department'],
@@ -439,7 +454,8 @@ def test_write_df_to_table(save_df_as_table,
"output_percentage": 0.0,
"success_percentage": 0.0,
"error_percentage": 100.0,
}, {'mode': 'append', "format": "bigquery", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}})
}, {'mode': 'append', "format": "bigquery", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [],
'options': {"mergeSchema": "true"}})
])
def test_write_error_stats(input_record,
expected_result,
@@ -502,8 +518,10 @@ def test_write_error_stats(input_record,
    setattr(_mock_context, 'get_dq_stats_table_name', 'test_dq_stats_table')

    if writer_config is None:
-       setattr(_mock_context, "_stats_table_writer_config", WrappedDataFrameWriter().mode("overwrite").format("delta").build())
-       setattr(_mock_context, 'get_stats_table_writer_config', WrappedDataFrameWriter().mode("overwrite").format("delta").build())
+       setattr(_mock_context, "_stats_table_writer_config",
+               WrappedDataFrameWriter().mode("overwrite").format("delta").build())
+       setattr(_mock_context, 'get_stats_table_writer_config',
+               WrappedDataFrameWriter().mode("overwrite").format("delta").build())
    else:
        setattr(_mock_context, "_stats_table_writer_config", writer_config)
        setattr(_mock_context, 'get_stats_table_writer_config', writer_config)
@@ -554,9 +572,6 @@ def test_write_error_stats(input_record,
                        "to_json(struct(*)) AS value").collect()


@pytest.mark.parametrize('table_name, rule_type',
[('test_error_table',
'row_dq'
@@ -608,7 +623,8 @@ def test_write_error_records_final_dependent(save_df_as_table,
.withColumn("meta_dq_run_date", lit("2022-12-27 10:39:44")) \
.orderBy("id").collect()
assert save_df_args[0][2] == table_name
save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name, _fixture_writer._context.get_target_and_error_table_writer_config)
save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name,
_fixture_writer._context.get_target_and_error_table_writer_config)


@pytest.mark.parametrize("test_data, expected_result", [
@@ -647,6 +663,7 @@ def test_generate_summarised_row_dq_res(test_data, expected_result):
    result = context.get_summarised_row_dq_res
    assert result == expected_result


@pytest.mark.parametrize('dq_rules, summarised_row_dq, expected_result',
                         [
                             (
