From ae8db99724742c2e9e66255d02e3aed08a9d6fe3 Mon Sep 17 00:00:00 2001 From: Rahul Sharma Date: Thu, 11 Apr 2024 10:52:15 -0400 Subject: [PATCH] Updated Spark 3.3 dependency This commit updates the Spark 3.3 dependency of Deequ. There are some breaking changes to the Scala APIs, from a Py4J perspective. In order to work around that, we use the Spark version to switch between the updated API and the old API. This is not sustainable and will be revisited in a future PR, or via a different release mechanism. The issue is that we have multiple branches for multiple Spark versions in Deequ, but only one branch in PyDeequ. The changes were verified by running the tests in Docker against Spark version 3.3. The docker file was also updated so that it copies over the pyproject.toml file and installs dependencies in a separate layer, before the code is copied. This allows for fast iteration of the code, without the need to install dependencies every time the docker image is built. --- Dockerfile | 10 ++++--- pydeequ/analyzers.py | 62 +++++++++++++++++++++++++++++++++++++------- pydeequ/checks.py | 29 ++++++++++++++++++--- pydeequ/configs.py | 3 ++- pydeequ/profiles.py | 2 ++ 5 files changed, 87 insertions(+), 19 deletions(-) diff --git a/Dockerfile b/Dockerfile index a7a236a..bdd9099 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,12 +16,14 @@ RUN pip3 --version RUN java -version RUN pip install poetry==1.7.1 -COPY . /python-deequ +RUN mkdir python-deequ +COPY pyproject.toml /python-deequ +COPY poetry.lock /python-deequ WORKDIR python-deequ -RUN poetry lock --no-update -RUN poetry install -RUN poetry add pyspark==3.3 +RUN poetry install -vvv +RUN poetry add pyspark==3.3 -vvv ENV SPARK_VERSION=3.3 +COPY . /python-deequ CMD poetry run python -m pytest -s tests diff --git a/pydeequ/analyzers.py b/pydeequ/analyzers.py index efd1361..4289094 100644 --- a/pydeequ/analyzers.py +++ b/pydeequ/analyzers.py @@ -10,7 +10,7 @@ from pydeequ.repository import MetricsRepository, ResultKey from enum import Enum from pydeequ.scala_utils import to_scala_seq - +from pydeequ.configs import SPARK_VERSION class _AnalyzerObject: """ @@ -303,7 +303,19 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.Compliance(self.instance, self.predicate, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.Compliance( + self.instance, + self.predicate, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.collection.Seq.empty() + ) + else: + return self._deequAnalyzers.Compliance( + self.instance, + self.predicate, + self._jvm.scala.Option.apply(self.where) + ) class Correlation(_AnalyzerObject): @@ -457,12 +469,22 @@ def _analyzer_jvm(self): """ if not self.maxDetailBins: self.maxDetailBins = getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$3")() - return self._deequAnalyzers.Histogram( - self.column, - self._jvm.scala.Option.apply(self.binningUdf), - self.maxDetailBins, - self._jvm.scala.Option.apply(self.where), - ) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.Histogram( + self.column, + self._jvm.scala.Option.apply(self.binningUdf), + self.maxDetailBins, + self._jvm.scala.Option.apply(self.where), + getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$5")(), + getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$6")() + ) + else: + return self._deequAnalyzers.Histogram( + self.column, + self._jvm.scala.Option.apply(self.binningUdf), + self.maxDetailBins, + self._jvm.scala.Option.apply(self.where) + ) class KLLParameters: @@ -553,7 +575,17 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.MaxLength(self.column, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.MaxLength( + self.column, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.Option.apply(None) + ) + else: + return self._deequAnalyzers.MaxLength( + self.column, + self._jvm.scala.Option.apply(self.where) + ) class Mean(_AnalyzerObject): @@ -619,7 +651,17 @@ def _analyzer_jvm(self): :return self """ - return self._deequAnalyzers.MinLength(self.column, self._jvm.scala.Option.apply(self.where)) + if SPARK_VERSION == "3.3": + return self._deequAnalyzers.MinLength( + self.column, + self._jvm.scala.Option.apply(self.where), + self._jvm.scala.Option.apply(None) + ) + else: + return self._deequAnalyzers.MinLength( + self.column, + self._jvm.scala.Option.apply(self.where) + ) class MutualInformation(_AnalyzerObject): diff --git a/pydeequ/checks.py b/pydeequ/checks.py index abf94d0..ebfb4ee 100644 --- a/pydeequ/checks.py +++ b/pydeequ/checks.py @@ -6,7 +6,7 @@ from pydeequ.check_functions import is_one from pydeequ.scala_utils import ScalaFunction1, to_scala_seq - +from pydeequ.configs import SPARK_VERSION # TODO implement custom assertions # TODO implement all methods without outside class dependencies @@ -418,7 +418,11 @@ def hasMinLength(self, column, assertion, hint=None): """ assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.hasMinLength(column, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.hasMinLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None)) + else: + self._Check = self._Check.hasMinLength(column, assertion_func) + return self def hasMaxLength(self, column, assertion, hint=None): @@ -433,7 +437,10 @@ def hasMaxLength(self, column, assertion, hint=None): """ assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.hasMaxLength(column, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.hasMaxLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None)) + else: + self._Check = self._Check.hasMaxLength(column, assertion_func, hint) return self def hasMin(self, column, assertion, hint=None): @@ -558,7 +565,21 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None): else getattr(self._Check, "satisfies$default$3")() ) hint = self._jvm.scala.Option.apply(hint) - self._Check = self._Check.satisfies(columnCondition, constraintName, assertion_func, hint) + if SPARK_VERSION == "3.3": + self._Check = self._Check.satisfies( + columnCondition, + constraintName, + assertion_func, + hint, + self._jvm.scala.collection.Seq.empty() + ) + else: + self._Check = self._Check.satisfies( + columnCondition, + constraintName, + assertion_func, + hint + ) return self def hasPattern(self, column, pattern, assertion=None, name=None, hint=None): diff --git a/pydeequ/configs.py b/pydeequ/configs.py index c3c885d..49cb277 100644 --- a/pydeequ/configs.py +++ b/pydeequ/configs.py @@ -5,7 +5,7 @@ SPARK_TO_DEEQU_COORD_MAPPING = { - "3.3": "com.amazon.deequ:deequ:2.0.3-spark-3.3", + "3.3": "com.amazon.deequ:deequ:2.0.4-spark-3.3", "3.2": "com.amazon.deequ:deequ:2.0.1-spark-3.2", "3.1": "com.amazon.deequ:deequ:2.0.0-spark-3.1", "3.0": "com.amazon.deequ:deequ:1.2.2-spark-3.0", @@ -40,5 +40,6 @@ def _get_deequ_maven_config(): ) +SPARK_VERSION = _get_spark_version() DEEQU_MAVEN_COORD = _get_deequ_maven_config() IS_DEEQU_V1 = re.search("com\.amazon\.deequ\:deequ\:1.*", DEEQU_MAVEN_COORD) is not None diff --git a/pydeequ/profiles.py b/pydeequ/profiles.py index a4a2056..fbbfd84 100644 --- a/pydeequ/profiles.py +++ b/pydeequ/profiles.py @@ -241,7 +241,9 @@ def __init__(self, spark_session: SparkSession): self._profiles = [] self.columnProfileClasses = { "StandardColumnProfile": StandardColumnProfile, + "StringColumnProfile": StandardColumnProfile, "NumericColumnProfile": NumericColumnProfile, + } def _columnProfilesFromColumnRunBuilderRun(self, run):