Handle exception for invalid column names containing . #102

Open. Wants to merge 1 commit into base: master.
52 changes: 51 additions & 1 deletion pydeequ/checks.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from enum import Enum

from pyspark.sql import SparkSession

from pydeequ.scala_utils import ScalaFunction1, to_scala_seq
@@ -100,6 +99,14 @@ def __init__(self, spark_session: SparkSession, level: CheckLevel, description:
for constraint in self.constraints:
self.addConstraint(constraint)

@staticmethod
def _handle_invalid_column(column):
"""
:param column: column to which constraint is to be applied
"""
if "." in column:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is quite strict, perhaps simply enclosing the field with ticks will do, like f"`column.name`"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@herminio-iovio passing through ticks also does not work and returns the same error

raise ValueError("column name cannot contain .")
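
For readers following the thread, here is a minimal sketch of the Spark behavior under discussion (a hypothetical local session and column names, purely for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("sample",)], ["col_a"]).withColumnRenamed("col_a", "a.b")

# Spark parses a bare dot as struct-field access, so this would raise AnalysisException:
# df.select("a.b")

# Backtick quoting does work for a plain Spark select:
df.select("`a.b`").show()

# Per the author's reply above, the Deequ constraint path still fails on a dotted
# name even when backtick-quoted, which is why this PR rejects such names outright.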

def addConstraints(self, constraints: list):
self.constraints.extend(constraints)
for constraint in constraints:
@@ -144,6 +151,7 @@ def isComplete(self, column, hint=None):
:return: isComplete self:A Check.scala object that asserts on a column completion.
"""
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.isComplete(column, hint)
return self

@@ -160,6 +168,7 @@ def hasCompleteness(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasCompleteness(column, assertion_func, hint)
print(self)
return self
@@ -172,6 +181,7 @@ def areComplete(self, columns, hint=None):
:return: areComplete self: A Check.scala object that asserts completion in the columns.
"""
hint = self._jvm.scala.Option.apply(hint)
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
self._Check = self._Check.areComplete(columns_seq, hint)
return self
@@ -184,6 +194,7 @@ def haveCompleteness(self, columns, assertion, hint=None):
:param str hint: A hint that states why a constraint could have failed.
:return: haveCompleteness self: A Check.scala object that implements the assertion on the columns.
"""
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
@@ -198,6 +209,7 @@ def areAnyComplete(self, columns, hint=None):
:return: areAnyComplete self: A Check.scala object that asserts completion in the columns.
"""
hint = self._jvm.scala.Option.apply(hint)
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
self._Check = self._Check.areAnyComplete(columns_seq, hint)
return self
@@ -210,6 +222,7 @@ def haveAnyCompleteness(self, columns, assertion, hint=None):
:param str hint: A hint that states why a constraint could have failed.
:return: haveAnyCompleteness self: A Check.scala object that asserts completion in the columns.
"""
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
@@ -225,6 +238,7 @@ def isUnique(self, column, hint=None):
:return: isUnique self: A Check.scala object that asserts uniqueness in the column.
"""
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.isUnique(column, hint)
return self

@@ -255,6 +269,7 @@ def hasUniqueness(self, columns, assertion, hint=None):
:return: hasUniqueness self: A Check object that asserts uniqueness in the columns.
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.hasUniqueness(columns_seq, assertion_func, hint)
@@ -272,6 +287,7 @@ def hasDistinctness(self, columns, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
self._Check = self._Check.hasDistinctness(columns_seq, assertion_func, hint)
return self
@@ -287,6 +303,7 @@ def hasUniqueValueRatio(self, columns, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
[self._handle_invalid_column(column) for column in columns]
columns_seq = to_scala_seq(self._jvm, columns)
self._Check = self._Check.hasUniqueValueRatio(columns_seq, assertion_func, hint)
return self
@@ -303,6 +320,7 @@ def hasNumberOfDistinctValues(self, column, assertion, binningUdf, maxBins, hint
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasNumberOfDistinctValues(column, assertion_func, binningUdf, maxBins, hint)
return self

@@ -318,6 +336,7 @@ def hasHistogramValues(self, column, assertion, binningUdf, maxBins, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasHistogramValues(column, assertion_func, binningUdf, maxBins, hint)
return self

@@ -333,6 +352,7 @@ def kllSketchSatisfies(self, column, assertion, kllParameters=None, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
params = self._jvm.scala.Option.apply(kllParameters._param if kllParameters else None)
self._Check = self._Check.kllSketchSatisfies(column, assertion_func, params, hint)
return self
@@ -363,6 +383,7 @@ def hasEntropy(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasEntropy(column, assertion_func, hint)
return self

@@ -379,6 +400,8 @@ def hasMutualInformation(self, columnA, columnB, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.hasMutualInformation(columnA, columnB, assertion_func, hint)
return self

@@ -394,6 +417,7 @@ def hasApproxQuantile(self, column, quantile, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasApproxQuantile(column, float(quantile), assertion_func, hint)
return self

@@ -409,6 +433,7 @@ def hasMinLength(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasMinLength(column, assertion_func, hint)
return self

@@ -424,6 +449,7 @@ def hasMaxLength(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
return self

@@ -440,6 +466,7 @@ def hasMin(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasMin(column, assertion_func, hint)
return self

@@ -456,6 +483,7 @@ def hasMax(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasMax(column, assertion_func, hint)
return self

@@ -470,6 +498,7 @@ def hasMean(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasMean(column, assertion_func, hint)
return self

@@ -484,6 +513,7 @@ def hasSum(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasSum(column, assertion_func, hint)
return self

@@ -498,6 +528,7 @@ def hasStandardDeviation(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasStandardDeviation(column, assertion_func, hint)
return self

@@ -512,6 +543,7 @@ def hasApproxCountDistinct(self, column, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasApproxCountDistinct(column, assertion_func, hint)
return self

@@ -527,6 +559,8 @@ def hasCorrelation(self, columnA, columnB, assertion, hint=None):
"""
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.hasCorrelation(columnA, columnB, assertion_func, hint)
return self

@@ -581,6 +615,7 @@ def containsCreditCardNumber(self, column, assertion=None, hint=None):
else getattr(self._Check, "containsCreditCardNumber$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.containsCreditCardNumber(column, assertion, hint)
return self

@@ -599,6 +634,7 @@ def containsEmail(self, column, assertion=None, hint=None):
else getattr(self._Check, "containsEmail$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.containsEmail(column, assertion, hint)
return self

@@ -617,6 +653,7 @@ def containsURL(self, column, assertion=None, hint=None):
else getattr(self._Check, "containsURL$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.containsURL(column, assertion, hint)
return self

@@ -636,6 +673,7 @@ def containsSocialSecurityNumber(self, column, assertion=None, hint=None):
else getattr(self._Check, "containsSocialSecurityNumber$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.containsSocialSecurityNumber(column, assertion, hint)
return self

@@ -657,6 +695,7 @@ def hasDataType(self, column, datatype: ConstrainableDataTypes, assertion=None,
else getattr(self._Check, "hasDataType$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.hasDataType(column, datatype_jvm, assertion, hint)
return self

@@ -676,6 +715,7 @@ def isNonNegative(self, column, assertion=None, hint=None):
else getattr(self._Check, "isNonNegative$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.isNonNegative(column, assertion_func, hint)
return self

@@ -694,6 +734,7 @@ def isPositive(self, column, assertion=None, hint=None):
else getattr(self._Check, "isPositive$default$2")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(column)
self._Check = self._Check.isPositive(column, assertion_func, hint)
return self

@@ -713,6 +754,8 @@ def isLessThan(self, columnA, columnB, assertion=None, hint=None):
else getattr(self._Check, "isLessThan$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.isLessThan(columnA, columnB, assertion_func, hint)
return self

@@ -732,6 +775,8 @@ def isLessThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
else getattr(self._Check, "isLessThanOrEqualTo$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.isLessThanOrEqualTo(columnA, columnB, assertion_func, hint)
return self

@@ -751,6 +796,8 @@ def isGreaterThan(self, columnA, columnB, assertion=None, hint=None):
else getattr(self._Check, "isGreaterThan$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.isGreaterThan(columnA, columnB, assertion_func, hint)
return self

@@ -770,6 +817,8 @@ def isGreaterThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
else getattr(self._Check, "isGreaterThanOrEqualTo$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._handle_invalid_column(columnA)
self._handle_invalid_column(columnB)
self._Check = self._Check.isGreaterThanOrEqualTo(columnA, columnB, assertion_func, hint)
return self

@@ -785,6 +834,7 @@ def isContainedIn(self, column, allowed_values):
arr = self._spark_session.sparkContext._gateway.new_array(self._jvm.java.lang.String, len(allowed_values))
for i in range(len(allowed_values)):
arr[i] = allowed_values[i]
self._handle_invalid_column(column)
self._Check = self._Check.isContainedIn(column, arr)
return self

18 changes: 14 additions & 4 deletions tests/test_checks.py
@@ -15,7 +15,8 @@ class TestChecks(unittest.TestCase):
def setUpClass(cls):
cls.spark = setup_pyspark().appName("test-checkss-local").getOrCreate()
cls.sc = cls.spark.sparkContext
cls.df = cls.sc.parallelize(

df = cls.sc.parallelize(
[
Row(
a="foo",
@@ -31,6 +32,7 @@ def setUpClass(cls):
ssn="123-45-6789",
URL="http://[email protected]:8080",
boolean="true",
column_with_dot="sample",
),
Row(
a="bar",
@@ -46,6 +48,7 @@ def setUpClass(cls):
ssn="123456789",
URL="http://foo.com/(something)?after=parens",
boolean="false",
column_with_dot="sample",
),
Row(
a="baz",
@@ -61,9 +64,13 @@ def setUpClass(cls):
ssn="000-00-0000",
URL="http://[email protected]:8080",
boolean="true",
column_with_dot="sample",
),
]
).toDF()
df = df.withColumnRenamed("column_with_dot", "column.with.dot")
cls.df = df
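
A note on this fixture trick: column.with.dot is not a valid Python identifier, so the dotted name cannot be written directly as a Row keyword argument. Building the column under a safe name and renaming it after toDF() sidesteps that:

# Not expressible directly in keyword syntax:
# Row(column.with.dot="sample")   # SyntaxError
# Hence the two-step construction used above:
# df = df.withColumnRenamed("column_with_dot", "column.with.dot")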


@classmethod
def tearDownClass(cls):
@@ -141,6 +148,7 @@ def hasDataType(self, column, datatype, assertion=None, hint=None):
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
df.show(truncate=False)
return df.select("constraint_status").collect()

def isComplete(self, column, hint=None):
@@ -585,9 +593,11 @@ def test_fail_containsSocialSecurityNumber(self):
def test_hasDataType(self):
self.assertEqual(self.hasDataType("a", ConstrainableDataTypes.String), [Row(constraint_status="Success")])
self.assertEqual(self.hasDataType("b", ConstrainableDataTypes.Numeric), [Row(constraint_status="Success")])
self.assertEqual(
self.hasDataType("boolean", ConstrainableDataTypes.Boolean), [Row(constraint_status="Success")]
)
self.assertEqual(self.hasDataType("boolean", ConstrainableDataTypes.Boolean), [Row(constraint_status="Success")])

def test_invalidColumnException(self):
with self.assertRaises(ValueError):
self.hasDataType("column.with.dot", ConstrainableDataTypes.String)

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_hasDataType(self):
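
End to end, the guard pinned down by test_invalidColumnException would surface to users roughly like this (a hypothetical sketch; it assumes this branch is installed and that spark is an active SparkSession):

from pydeequ.checks import Check, CheckLevel

check = Check(spark, CheckLevel.Error, "dotted-column guard")
try:
    check.isComplete("column.with.dot")
except ValueError as err:
    print(err)  # column name cannot contain .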