diff --git a/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala
new file mode 100644
index 000000000..9eff89b6d
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ *
+ */
+
+package com.amazon.deequ.analyzers
+
+import com.amazon.deequ.metrics.DoubleMetric
+import com.amazon.deequ.metrics.Entity
+import org.apache.spark.sql.DataFrame
+
+/**
+ * Dataset-level analyzer that reports the number of columns of the input DataFrame.
+ * Filter conditions are not supported and are rejected by [[computeStateFrom]].
+ */
+case class ColumnCount() extends Analyzer[NumMatches, DoubleMetric] {
+
+  val name = "ColumnCount"
+  val instance = "*"
+  val entity = Entity.Dataset
+
+  /**
+   * Compute the state (sufficient statistics) from the data
+   *
+   * @param data the input dataframe
+   * @param filterCondition must be empty; ColumnCount does not accept a filter condition
+   * @return the number of columns in the input
+   * @throws IllegalArgumentException if a filter condition is given
+   */
+  override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = {
+    if (filterCondition.isDefined) {
+      throw new IllegalArgumentException("ColumnCount does not accept a filter condition")
+    } else {
+      val numColumns = data.columns.length
+      Some(NumMatches(numColumns))
+    }
+  }
+
+  /**
+   * Compute the metric from the state (sufficient statistics)
+   *
+   * @param state the computed state from [[computeStateFrom]]
+   * @return a double metric indicating the number of columns for this analyzer
+   */
+  override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = {
+    state
+      .map(v => Analyzers.metricFromValue(v.metricValue(), name, instance, entity))
+      .getOrElse(Analyzers.metricFromEmpty(this, name, instance, entity))
+  }
+
+  /**
+   * Compute the metric from a failure - reports the exception thrown while trying to count columns
+   */
+  override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = {
+    Analyzers.metricFromFailure(failure, name, instance, entity)
+  }
+}
diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
index e07e2d11f..edd4f8e97 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
@@ -26,6 +26,17 @@ import scala.util.Failure
 import scala.util.Success
 import scala.util.Try
 
+case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {
+  lazy val state = stateOrError.left.get
+  lazy val error = stateOrError.right.get
+
+  override def sum(other: CustomSqlState): CustomSqlState = {
+    CustomSqlState(Left(state + other.state))
+  }
+
+  override def metricValue(): Double = state
+}
+
 case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] {
   /**
    * Compute the state (sufficient statistics) from the data
diff --git a/src/main/scala/com/amazon/deequ/analyzers/Size.scala b/src/main/scala/com/amazon/deequ/analyzers/Size.scala
index a5080084a..c56083abe 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Size.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Size.scala
@@ -20,17 +20,6 @@ import com.amazon.deequ.metrics.Entity
 import org.apache.spark.sql.{Column, Row}
 import Analyzers._
 
-case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {
-  lazy val state = stateOrError.left.get
-  lazy val error = stateOrError.right.get
-
-  override def sum(other: CustomSqlState): CustomSqlState = {
-    CustomSqlState(Left(state + other.state))
-  }
-
-  override def metricValue(): Double = state
-}
-
 case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] {
 
   override def sum(other: NumMatches): NumMatches = {
diff --git a/src/main/scala/com/amazon/deequ/checks/Check.scala b/src/main/scala/com/amazon/deequ/checks/Check.scala
index 2537922be..1e1048921 100644
--- a/src/main/scala/com/amazon/deequ/checks/Check.scala
+++ b/src/main/scala/com/amazon/deequ/checks/Check.scala
@@ -127,6 +127,24 @@ case class Check(
     addFilterableConstraint { filter => Constraint.sizeConstraint(assertion, filter, hint) }
   }
 
+  /**
+    * Creates a constraint that asserts on the number of columns of the data frame.
+    *
+    * NOTE(review): the filter passed in by `addFilterableConstraint` is not forwarded, since
+    * [[Constraint.columnCountConstraint]] takes no filter. Presumably a filter applied to
+    * the returned check has no effect on this constraint - confirm before relying on it.
+    *
+    * @param assertion Function that receives the column count as a long and returns a boolean,
+    *                  e.g. .hasColumnCount(_ == 3)
+    * @param hint A hint to provide additional context why a constraint could have failed
+    */
+  def hasColumnCount(assertion: Long => Boolean, hint: Option[String] = None)
+    : CheckWithLastConstraintFilterable = {
+    addFilterableConstraint {
+      _ => Constraint.columnCountConstraint(assertion, hint)
+    }
+  }
+
   /**
     * Creates a constraint that asserts on a column completion.
     *
diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
index c0e6e9b9d..e289b3859 100644
--- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
+++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -128,6 +128,27 @@ object Constraint {
     new NamedConstraint(constraint, s"SizeConstraint($size)")
   }
 
+  /**
+    * Runs the ColumnCount analysis on the data frame and executes the assertion
+    *
+    * @param assertion Function that receives the column count as a long and returns a boolean
+    * @param hint A hint to provide additional context why a constraint could have failed
+    */
+  def columnCountConstraint(assertion: Long => Boolean, hint: Option[String] = None): Constraint = {
+    val colCount = ColumnCount()
+    fromAnalyzer(colCount, assertion, hint)
+  }
+
+  /**
+    * Wraps the given ColumnCount analyzer in a named, analysis based constraint. `_.toLong` is
+    * passed along so that the Double metric value can be handed to the Long based assertion.
+    */
+  def fromAnalyzer(colCount: ColumnCount, assertion: Long => Boolean, hint: Option[String]): Constraint = {
+    val constraint = AnalysisBasedConstraint[NumMatches, Double, Long](colCount, assertion, Some(_.toLong), hint)
+
+    new NamedConstraint(constraint, name = s"ColumnCountConstraint($colCount)")
+  }
+
   /**
     * Runs Histogram analysis on the given column and executes the assertion
     *
diff --git a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala
index 20017fa71..728ce866c
100644
--- a/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala
+++ b/src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala
@@ -22,7 +22,7 @@ import com.amazon.deequ.analyzers.{Distance, QuantileNonSample}
 import com.amazon.deequ.metrics.BucketValue
 import com.amazon.deequ.utils.FixtureSupport
 import org.scalatest.WordSpec
-import com.amazon.deequ.metrics.{BucketValue}
+import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper
 
 class KLLDistanceTest extends WordSpec with SparkContextSpec
   with FixtureSupport{
@@ -88,7 +88,7 @@ class KLLDistanceTest extends WordSpec with SparkContextSpec
     val sample2 = scala.collection.mutable.Map(
       "a" -> 22L, "b" -> 20L, "c" -> 25L, "d" -> 12L, "e" -> 13L, "f" -> 15L)
     val distance = Distance.categoricalDistance(sample1, sample2, method = LInfinityMethod(alpha = Some(0.003)))
-    assert(distance == 0.2726338046550349)
+    assert(distance === 0.2726338046550349 +- 1E-14)
   }
 
   "Categorial distance should compute correct linf_robust with different alpha value .1" in {
diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
index df13ea901..146579e8e 100644
--- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
+++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -61,6 +61,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
 
       val checkToSucceed = Check(CheckLevel.Error, "group-1")
         .isComplete("att1")
+        .hasColumnCount(_ == 3)
         .hasCompleteness("att1", _ == 1.0)
 
       val checkToErrorOut = Check(CheckLevel.Error, "group-2-E")
diff --git a/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala
new file mode 100644
index 000000000..00df2758c
--- /dev/null
+++ b/src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala
@@ -0,0 +1,53 @@
+/*
+ * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ *     http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ *
+ */
+
+package com.amazon.deequ.analyzers
+
+import com.amazon.deequ.SparkContextSpec
+import com.amazon.deequ.utils.FixtureSupport
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+
+import scala.util.Failure
+import scala.util.Success
+
+class ColumnCountTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
+  "ColumnCount" should {
+    "return column count for a dataset" in withSparkSession { session =>
+      val data = getDfWithStringColumns(session)
+      val colCount = ColumnCount()
+
+      val state = colCount.computeStateFrom(data)
+      state.map(_.metricValue()) shouldBe Some(5.0)
+
+      val metric = colCount.computeMetricFrom(state)
+      metric.fullColumn shouldBe None
+      metric.value shouldBe Success(5.0)
+    }
+
+    "reject a filter condition" in withSparkSession { session =>
+      val data = getDfWithStringColumns(session)
+
+      // computeStateFrom throws as soon as any filter is supplied, whatever its content
+      an[IllegalArgumentException] should be thrownBy {
+        ColumnCount().computeStateFrom(data, Some("1 = 1"))
+      }
+    }
+  }
+}