diff --git a/pydeequ/checks.py b/pydeequ/checks.py
index fb13b3a..50f5d90 100644
--- a/pydeequ/checks.py
+++ b/pydeequ/checks.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
 from enum import Enum
-
 from pyspark.sql import SparkSession
 
 from pydeequ.scala_utils import ScalaFunction1, to_scala_seq
@@ -100,6 +99,14 @@ def __init__(self, spark_session: SparkSession, level: CheckLevel, description:
         for constraint in self.constraints:
             self.addConstraint(constraint)
 
+    @staticmethod
+    def _handle_invalid_column(column):
+        """Raise a ValueError if the column name contains a '.'.
+        :param column: column to which the constraint is to be applied
+        """
+        if "." in column:
+            raise ValueError(f"column name '{column}' cannot contain '.'")
+
     def addConstraints(self, constraints: list):
         self.constraints.extend(constraints)
         for constraint in constraints:
@@ -144,6 +151,7 @@ def isComplete(self, column, hint=None):
         :return: isComplete self:A Check.scala object that asserts on a column completion.
         """
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.isComplete(column, hint)
         return self
 
@@ -160,6 +168,7 @@ def hasCompleteness(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasCompleteness(column, assertion_func, hint)
         print(self)
         return self
@@ -172,6 +181,7 @@ def areComplete(self, columns, hint=None):
         :return: areComplete self: A Check.scala object that asserts completion in the columns.
         """
         hint = self._jvm.scala.Option.apply(hint)
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         self._Check = self._Check.areComplete(columns_seq, hint)
         return self
@@ -184,6 +194,7 @@ def haveCompleteness(self, columns, assertion, hint=None):
         :param str hint: A hint that states why a constraint could have failed.
         :return: haveCompleteness self: A Check.scala object that implements the assertion on the columns.
         """
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
@@ -198,6 +209,7 @@ def areAnyComplete(self, columns, hint=None):
         :return: areAnyComplete self: A Check.scala object that asserts completion in the columns.
         """
         hint = self._jvm.scala.Option.apply(hint)
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         self._Check = self._Check.areAnyComplete(columns_seq, hint)
         return self
@@ -210,6 +222,7 @@ def haveAnyCompleteness(self, columns, assertion, hint=None):
         :param str hint: A hint that states why a constraint could have failed.
         :return: haveAnyCompleteness self: A Check.scala object that asserts completion in the columns.
         """
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
@@ -225,6 +238,7 @@ def isUnique(self, column, hint=None):
         :return: isUnique self: A Check.scala object that asserts uniqueness in the column.
         """
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.isUnique(column, hint)
         return self
 
@@ -255,6 +269,7 @@ def hasUniqueness(self, columns, assertion, hint=None):
         :return: hasUniqueness self: A Check object that asserts uniqueness in the columns.
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         hint = self._jvm.scala.Option.apply(hint)
         self._Check = self._Check.hasUniqueness(columns_seq, assertion_func, hint)
@@ -272,6 +287,7 @@ def hasDistinctness(self, columns, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         self._Check = self._Check.hasDistinctness(columns_seq, assertion_func, hint)
         return self
@@ -287,6 +303,7 @@ def hasUniqueValueRatio(self, columns, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        [self._handle_invalid_column(column) for column in columns]
         columns_seq = to_scala_seq(self._jvm, columns)
         self._Check = self._Check.hasUniqueValueRatio(columns_seq, assertion_func, hint)
         return self
@@ -303,6 +320,7 @@ def hasNumberOfDistinctValues(self, column, assertion, binningUdf, maxBins, hint
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasNumberOfDistinctValues(column, assertion_func, binningUdf, maxBins, hint)
         return self
 
@@ -318,6 +336,7 @@ def hasHistogramValues(self, column, assertion, binningUdf, maxBins, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasHistogramValues(column, assertion_func, binningUdf, maxBins, hint)
         return self
 
@@ -333,6 +352,7 @@ def kllSketchSatisfies(self, column, assertion, kllParameters=None, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         params = self._jvm.scala.Option.apply(kllParameters._param if kllParameters else None)
         self._Check = self._Check.kllSketchSatisfies(column, assertion_func, params, hint)
         return self
@@ -363,6 +383,7 @@ def hasEntropy(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasEntropy(column, assertion_func, hint)
         return self
 
@@ -379,6 +400,8 @@ def hasMutualInformation(self, columnA, columnB, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.hasMutualInformation(columnA, columnB, assertion_func, hint)
         return self
 
@@ -394,6 +417,7 @@ def hasApproxQuantile(self, column, quantile, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasApproxQuantile(column, float(quantile), assertion_func, hint)
         return self
 
@@ -409,6 +433,7 @@ def hasMinLength(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasMinLength(column, assertion_func, hint)
         return self
 
@@ -424,6 +449,7 @@ def hasMaxLength(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
         return self
 
@@ -440,6 +466,7 @@ def hasMin(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasMin(column, assertion_func, hint)
         return self
 
@@ -456,6 +483,7 @@ def hasMax(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasMax(column, assertion_func, hint)
         return self
 
@@ -470,6 +498,7 @@ def hasMean(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasMean(column, assertion_func, hint)
         return self
 
@@ -484,6 +513,7 @@ def hasSum(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasSum(column, assertion_func, hint)
         return self
 
@@ -498,6 +528,7 @@ def hasStandardDeviation(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasStandardDeviation(column, assertion_func, hint)
         return self
 
@@ -512,6 +543,7 @@ def hasApproxCountDistinct(self, column, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasApproxCountDistinct(column, assertion_func, hint)
         return self
 
@@ -527,6 +559,8 @@ def hasCorrelation(self, columnA, columnB, assertion, hint=None):
         """
         assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.hasCorrelation(columnA, columnB, assertion_func, hint)
         return self
 
@@ -581,6 +615,7 @@ def containsCreditCardNumber(self, column, assertion=None, hint=None):
             else getattr(self._Check, "containsCreditCardNumber$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.containsCreditCardNumber(column, assertion, hint)
         return self
 
@@ -599,6 +634,7 @@ def containsEmail(self, column, assertion=None, hint=None):
             else getattr(self._Check, "containsEmail$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.containsEmail(column, assertion, hint)
         return self
 
@@ -617,6 +653,7 @@ def containsURL(self, column, assertion=None, hint=None):
             else getattr(self._Check, "containsURL$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.containsURL(column, assertion, hint)
         return self
 
@@ -636,6 +673,7 @@ def containsSocialSecurityNumber(self, column, assertion=None, hint=None):
             else getattr(self._Check, "containsSocialSecurityNumber$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.containsSocialSecurityNumber(column, assertion, hint)
         return self
 
@@ -657,6 +695,7 @@ def hasDataType(self, column, datatype: ConstrainableDataTypes, assertion=None,
             else getattr(self._Check, "hasDataType$default$3")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.hasDataType(column, datatype_jvm, assertion, hint)
         return self
 
@@ -676,6 +715,7 @@ def isNonNegative(self, column, assertion=None, hint=None):
             else getattr(self._Check, "isNonNegative$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.isNonNegative(column, assertion_func, hint)
         return self
 
@@ -694,6 +734,7 @@ def isPositive(self, column, assertion=None, hint=None):
             else getattr(self._Check, "isPositive$default$2")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(column)
         self._Check = self._Check.isPositive(column, assertion_func, hint)
         return self
 
@@ -713,6 +754,8 @@ def isLessThan(self, columnA, columnB, assertion=None, hint=None):
             else getattr(self._Check, "isLessThan$default$3")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.isLessThan(columnA, columnB, assertion_func, hint)
         return self
 
@@ -732,6 +775,8 @@ def isLessThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
             else getattr(self._Check, "isLessThanOrEqualTo$default$3")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.isLessThanOrEqualTo(columnA, columnB, assertion_func, hint)
         return self
 
@@ -751,6 +796,8 @@ def isGreaterThan(self, columnA, columnB, assertion=None, hint=None):
             else getattr(self._Check, "isGreaterThan$default$3")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.isGreaterThan(columnA, columnB, assertion_func, hint)
         return self
 
@@ -770,6 +817,8 @@ def isGreaterThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
             else getattr(self._Check, "isGreaterThanOrEqualTo$default$3")()
         )
         hint = self._jvm.scala.Option.apply(hint)
+        self._handle_invalid_column(columnA)
+        self._handle_invalid_column(columnB)
         self._Check = self._Check.isGreaterThanOrEqualTo(columnA, columnB, assertion_func, hint)
         return self
 
@@ -785,6 +834,7 @@ def isContainedIn(self, column, allowed_values):
         arr = self._spark_session.sparkContext._gateway.new_array(self._jvm.java.lang.String, len(allowed_values))
         for i in range(len(allowed_values)):
             arr[i] = allowed_values[i]
+        self._handle_invalid_column(column)
         self._Check = self._Check.isContainedIn(column, arr)
         return self
 
diff --git a/tests/test_checks.py b/tests/test_checks.py
index f33f266..eac2022 100644
--- a/tests/test_checks.py
+++ b/tests/test_checks.py
@@ -15,7 +15,8 @@ class TestChecks(unittest.TestCase):
     def setUpClass(cls):
         cls.spark = setup_pyspark().appName("test-checkss-local").getOrCreate()
         cls.sc = cls.spark.sparkContext
-        cls.df = cls.sc.parallelize(
+
+        df = cls.sc.parallelize(
             [
                 Row(
                     a="foo",
@@ -31,6 +32,7 @@ def setUpClass(cls):
                     ssn="123-45-6789",
                     URL="http://userid@example.com:8080",
                     boolean="true",
+                    column_with_dot="sample",
                 ),
                 Row(
                     a="bar",
@@ -46,6 +48,7 @@ def setUpClass(cls):
                     ssn="123456789",
                     URL="http://foo.com/(something)?after=parens",
                     boolean="false",
+                    column_with_dot="sample",
                 ),
                 Row(
                     a="baz",
@@ -61,9 +64,13 @@ def setUpClass(cls):
                     ssn="000-00-0000",
                     URL="http://userid@example.com:8080",
                     boolean="true",
+                    column_with_dot="sample",
                 ),
             ]
         ).toDF()
+        df = df.withColumnRenamed("column_with_dot", "column.with.dot")
+        cls.df = df
+
 
     @classmethod
     def tearDownClass(cls):
@@ -141,6 +148,7 @@ def hasDataType(self, column, datatype, assertion=None, hint=None):
         )
 
         df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
+        df.show(truncate=False)
         return df.select("constraint_status").collect()
 
     def isComplete(self, column, hint=None):
@@ -585,9 +593,11 @@ def test_fail_containsSocialSecurityNumber(self):
     def test_hasDataType(self):
         self.assertEqual(self.hasDataType("a", ConstrainableDataTypes.String), [Row(constraint_status="Success")])
         self.assertEqual(self.hasDataType("b", ConstrainableDataTypes.Numeric), [Row(constraint_status="Success")])
-        self.assertEqual(
-            self.hasDataType("boolean", ConstrainableDataTypes.Boolean), [Row(constraint_status="Success")]
-        )
+        self.assertEqual(self.hasDataType("boolean", ConstrainableDataTypes.Boolean), [Row(constraint_status="Success")])
+
+    def test_invalidColumnException(self):
+        with self.assertRaises(ValueError):
+            self.hasDataType("column.with.dot", ConstrainableDataTypes.String)
 
     @pytest.mark.xfail(reason="@unittest.expectedFailure")
     def test_fail_hasDataType(self):
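
For illustration, here is roughly how the new guard surfaces to callers. This is a minimal sketch, not part of the diff: it assumes a `spark` session already configured for Deequ (as the tests' `setup_pyspark()` helper provides) and an arbitrary check description.

    from pydeequ.checks import Check, CheckLevel

    # With this change, a constraint on a dotted column name fails fast in
    # Python with a ValueError, before anything is sent to the JVM-side
    # Deequ Check object.
    check = Check(spark, CheckLevel.Error, "integrity check")
    try:
        check.isComplete("column.with.dot")
    except ValueError as e:
        print(e)  # column name 'column.with.dot' cannot contain '.'

Because the guard runs at constraint-definition time rather than when the verification suite runs, the new test only needs `assertRaises(ValueError)` around the `hasDataType` helper. Callers whose data genuinely has dotted column names can rename those columns before defining checks, the inverse of the `withColumnRenamed` call the test fixture uses to construct the offending column.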