From 6fc2010d57f89fb4f5310374acc2498095128a3f Mon Sep 17 00:00:00 2001 From: Hendrik Saly Date: Tue, 20 Aug 2024 19:54:56 +0200 Subject: [PATCH] Implement stddev_samp and stddev_pop ppl stats function (#549) * Implement stddev_samp and stddev_pop ppl stats function Signed-off-by: Hendrik Saly * add tests for stats: stdev_samp, stdev_pop Signed-off-by: Kacper Trochimiak * Fix scala style Signed-off-by: Hendrik Saly * Add functions to readme Signed-off-by: Hendrik Saly --------- Signed-off-by: Hendrik Saly Signed-off-by: Kacper Trochimiak Co-authored-by: Kacper Trochimiak --- ...ntSparkPPLAggregationWithSpanITSuite.scala | 91 ++++++- .../FlintSparkPPLAggregationsITSuite.scala | 235 ++++++++++++++++++ ppl-spark-integration/README.md | 2 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 2 + .../sql/ppl/utils/AggregatorTranslator.java | 4 + ...ggregationQueriesTranslatorTestSuite.scala | 222 +++++++++++++++++ 6 files changed, 555 insertions(+), 1 deletion(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index b3abf8438..1e80c94b4 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, EqualTo, Floor, Literal, Multiply, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import 
org.apache.spark.sql.streaming.StreamTest @@ -262,4 +262,93 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + /** + * | stddev_samp(age) | age_span | + * |-------------------:|:---------| + * | 3.5355339059327378 | 20 | + */ + test( + "create ppl age sample stddev by span of interval of 10 years query with country filter test ") { + val frame = sql(s""" + | source = $testTable | where country != 'USA' | stats stddev_samp(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results: Row(stddev_samp(age), age_span) + val expectedResults: Array[Row] = Array(Row(3.5355339059327378d, 20L)) + + // Compare the results, ordering rows by age_span (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Aggregate over Filter + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(countryField, Literal("USA"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | stddev_pop(age) | age_span | + * |----------------:|:---------| + * | 2.5 | 20 | + * | 0 | 30 
| + */ + test( + "create ppl age population stddev by span of interval of 10 years query with state filter test ") { + val frame = sql(s""" + | source = $testTable | where state != 'California' | stats stddev_pop(age) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results: Row(stddev_pop(age), age_span) + val expectedResults: Array[Row] = Array(Row(2.5d, 20L), Row(0d, 30L)) + + // Compare the results, ordering rows by age_span (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Aggregate over Filter + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("California"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index 3bc227e7d..4f9d4c64e 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -380,4 +380,239 @@ class FlintSparkPPLAggregationsITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + test("create ppl age sample stddev") { + val frame = sql(s""" + | source = $testTable| stats stddev_samp(age) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(22.86737122335374d)) + // Compare the results; the only column, stddev_samp(age), is at index 0 + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over an ungrouped Aggregate + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age sample stddev group by country query test with sort") { + val frame = sql(s""" + | source = $testTable | stats stddev_samp(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results: Row(stddev_samp(age), country) + val expectedResults: Array[Row] = + Array(Row(3.5355339059327378d, "Canada"), 
Row(28.284271247461902d, "USA")) + + // Compare the results, ordering rows by country (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Sort over Aggregate + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age sample stddev group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats stddev_samp(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results; Canada is expected to be null (presumably one row remains) + val expectedResults: Array[Row] = Array(Row(null, "Canada"), Row(28.284271247461902d, "USA")) + + // Compare the results, ordering rows by country (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) 
+ + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Aggregate over Filter + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create ppl age population stddev") { + val frame = sql(s""" + | source = $testTable| stats stddev_pop(age) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(19.803724397193573d)) + // Compare the results; the only column, stddev_pop(age), is at index 0 + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over an ungrouped Aggregate + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), 
isDistinct = false), + "stddev_pop(age)")() + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age population stddev group by country query test with sort") { + val frame = sql(s""" + | source = $testTable | stats stddev_pop(age) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results: Row(stddev_pop(age), country) + val expectedResults: Array[Row] = Array(Row(2.5d, "Canada"), Row(20d, "USA")) + + // Compare the results, ordering rows by country (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Sort over Aggregate + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === 
compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl age population stddev group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats stddev_pop(age) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results: Row(stddev_pop(age), country) + val expectedResults: Array[Row] = Array(Row(0d, "Canada"), Row(20d, "USA")) + + // Compare the results, ordering rows by country (column index 1) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan: Project(*) over Aggregate over Filter + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 58558b2ce..67cecd48d 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -265,6 +265,8 @@ Limitation: Overriding existing field is 
unsupported, following queries throw ex - `source = table | where a < 50 | stats avg(c) ` - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` + - `source = table | stats stddev_samp(c)` + - `source = table | stats stddev_pop(c)` **Aggregations With Span** - `source = table | stats count(a) by span(a, 10) as a_span` diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 19d480327..6f56550c9 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -234,6 +234,8 @@ statsFunctionName | SUM | MIN | MAX + | STDDEV_SAMP + | STDDEV_POP ; takeAggFunction diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index e15324cc0..eba60248d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -35,6 +35,10 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false); case SUM: return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false); + case STDDEV_POP: + return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + case STDDEV_SAMP: + return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index ba634cc1c..61190294b 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -372,4 +372,226 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test price sample stddev group by product sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(price) by product | sort product", + false), + context) + val star = Seq(UnresolvedStar(None)) // Expected plan: Project(*) over Sort over Aggregate + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(priceField), isDistinct = false), + "stddev_samp(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price sample stddev with alias and filter") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table category = 'vegetable' | stats stddev_samp(price) as dev_samp", + false), + context) + val star = Seq(UnresolvedStar(None)) // Expected plan: Project(*) over Aggregate over Filter + val categoryField = UnresolvedAttribute("category") + val priceField = UnresolvedAttribute("price") + val 
tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(priceField), isDistinct = false), + "dev_samp")()) + val filterExpr = EqualTo(categoryField, Literal("vegetable")) + val filterPlan = Filter(filterExpr, tableRelation) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test age sample stddev by span of interval of 5 years query with sort ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(age) by span(age, 5) as age_span | sort age", + false), + context) + // Define the expected logical plan: Project(*) over Sort over Aggregate by age_span + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false), + "stddev_samp(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(5))), Literal(5)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test number of flights sample stddev by airport with alias and limit") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_samp(no_of_flights) as dev_samp_flights by airport | head 10", + false), + context) + // Define the expected logical plan: Project(*) over GlobalLimit/LocalLimit over Aggregate + val star = Seq(UnresolvedStar(None)) + val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = 
UnresolvedAttribute("airport") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(airportField, "airport")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(numberOfFlightsField), isDistinct = false), + "dev_samp_flights")() + val airportAlias = Alias(airportField, "airport")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, airportAlias), tableRelation) + val planWithLimit = GlobalLimit(Literal(10), LocalLimit(Literal(10), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price population stddev group by product sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(price) by product | sort product", + false), + context) + val star = Seq(UnresolvedStar(None)) // Expected plan: Project(*) over Sort over Aggregate + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(priceField), isDistinct = false), + "stddev_pop(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price population stddev with alias and filter") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table category = 'vegetable' | stats stddev_pop(price) as dev_pop", + false), + context) + val star = Seq(UnresolvedStar(None)) // Expected plan: Project(*) over Aggregate over Filter + 
val categoryField = UnresolvedAttribute("category") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(priceField), isDistinct = false), + "dev_pop")()) + val filterExpr = EqualTo(categoryField, Literal("vegetable")) + val filterPlan = Filter(filterExpr, tableRelation) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test age population stddev by span of interval of 5 years query with sort ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(age) by span(age, 5) as age_span | sort age", + false), + context) + // Define the expected logical plan: Project(*) over Sort over Aggregate by age_span + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false), + "stddev_pop(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(5))), Literal(5)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test number of flights population stddev by airport with alias and limit") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats stddev_pop(no_of_flights) as dev_pop_flights by airport | head 50", + false), + context) + // Define the expected logical plan: Project(*) over GlobalLimit/LocalLimit over Aggregate + val star = Seq(UnresolvedStar(None)) 
+ val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = UnresolvedAttribute("airport") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(airportField, "airport")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("STDDEV_POP"), Seq(numberOfFlightsField), isDistinct = false), + "dev_pop_flights")() + val airportAlias = Alias(airportField, "airport")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, airportAlias), tableRelation) + val planWithLimit = GlobalLimit(Literal(50), LocalLimit(Literal(50), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + + comparePlans(expectedPlan, logPlan, false) + } + }