From f283df840170fdb305324f042de303b03e7e0190 Mon Sep 17 00:00:00 2001 From: Hendrik Saly Date: Tue, 20 Aug 2024 23:55:59 +0200 Subject: [PATCH] Add percentile PPL function (#547) * percentile prototype Signed-off-by: Hendrik Saly * add tests for stats: percentile Signed-off-by: Kacper Trochimiak * Add PERCENTILE_APPROX Signed-off-by: Hendrik Saly * Add functions to readme Signed-off-by: Hendrik Saly * Add null checks Signed-off-by: Hendrik Saly * Fix tests, add tests Signed-off-by: Hendrik Saly --------- Signed-off-by: Hendrik Saly Signed-off-by: Kacper Trochimiak Co-authored-by: Kacper Trochimiak --- ...ntSparkPPLAggregationWithSpanITSuite.scala | 97 ++++++++++ .../FlintSparkPPLAggregationsITSuite.scala | 182 ++++++++++++++++++ ppl-spark-integration/README.md | 2 + .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../src/main/antlr4/OpenSearchPPLParser.g4 | 8 +- .../function/BuiltinFunctionName.java | 4 + .../sql/ppl/parser/AstExpressionBuilder.java | 8 + .../sql/ppl/utils/AggregatorTranslator.java | 34 ++++ ...ggregationQueriesTranslatorTestSuite.scala | 164 ++++++++++++++++ 9 files changed, 497 insertions(+), 3 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala index 1e80c94b4..3ffe05e81 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationWithSpanITSuite.scala @@ -351,4 +351,101 @@ class FlintSparkPPLAggregationWithSpanITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + /** + * | age_span | age_percentile | + * |:---------|---------------:| + * | 20 | 25 | + * | 30 | 30 | + * | 70 | 70 | + */ + test( + "create ppl simple age 60th percentile by span of interval of 10 years query with state filter test ") { + val frame = sql(s""" + | source = $testTable | where state != 'Quebec' | stats percentile(age, 60) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(25d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val percentage = Literal(0.6) + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false), + "percentile(age, 60)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + /** + * | age_span | age_percentile | + * |:---------|---------------:| + * | 20 | 25 | + * | 30 | 30 | + * | 70 | 70 | + */ + test( + "create ppl simple age 60th percentile approx by span of interval of 10 years query with state filter test ") { + val frame = sql(s""" + | source = $testTable | where state != 'Quebec' | stats percentile_approx(age, 60) by span(age, 10) as age_span + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(70d, 70L), Row(30d, 30L), Row(25d, 20L)) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val percentage = Literal(0.6) + val stateField = UnresolvedAttribute("state") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val aggregateExpressions = + Alias( + UnresolvedFunction( + Seq("PERCENTILE_APPROX"), + Seq(ageField, percentage), + isDistinct = false), + "percentile_approx(age, 60)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val filterExpr = Not(EqualTo(stateField, Literal("Quebec"))) + val filterPlan = Filter(filterExpr, table) + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala index 4f9d4c64e..c638cd750 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAggregationsITSuite.scala @@ -615,4 +615,186 @@ class FlintSparkPPLAggregationsITSuite // Compare the two plans assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + + test("create ppl simple age 50th percentile ") { + val frame = sql(s""" + | source = $testTable| stats percentile(age, 50) + | """.stripMargin) + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(27.5)) + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val percentage = Literal("0.5") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false), + "percentile(age, 50)")() + val aggregatePlan = + Aggregate(Seq.empty, Seq(aggregateExpressions), table) + val expectedPlan = Project(star, aggregatePlan) + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl simple age 20th percentile group by country query test with sort") { + val frame = sql(s""" + | source = $testTable | stats percentile(age, 20) by country | sort country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(21d, "Canada"), Row(38d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert( + results.sorted.sameElements(expectedResults.sorted), + s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}") + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val percentage = Literal("0.2") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false), + "percentile(age, 20)")() + val productAlias = Alias(countryField, "country")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table) + val sortedPlan: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("country"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + // Compare the two plans + assert( + compareByString(expectedPlan) === compareByString(logicalPlan), + s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}") + } + + test("create ppl simple age 40th percentile group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats percentile(age, 40) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(20d, "Canada"), Row(46d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val percentage = Literal("0.4") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false), + "percentile(age, 40)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test( + "create ppl simple age 40th percentile approx group by country with state filter query test") { + val frame = sql(s""" + | source = $testTable | where state != 'Ontario' | stats percentile_approx(age, 40) by country + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array(Row(20d, "Canada"), Row(30d, "USA")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val stateField = UnresolvedAttribute("state") + val countryField = UnresolvedAttribute("country") + val ageField = UnresolvedAttribute("age") + val percentage = Literal("0.4") + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + + val groupByAttributes = Seq(Alias(countryField, "country")()) + val filterExpr = Not(EqualTo(stateField, Literal("Ontario"))) + val filterPlan = Filter(filterExpr, table) + val aggregateExpressions = + Alias( + UnresolvedFunction( + Seq("PERCENTILE_APPROX"), + Seq(ageField, percentage), + isDistinct = false), + "percentile_approx(age, 40)")() + val productAlias = Alias(countryField, "country")() + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + // Compare the two plans + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("create failing ppl percentile approx - due to too high percentage value test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable | stats percentile_approx(age, 200) by country + | """.stripMargin) + } + assert(thrown.getMessage === "Unsupported value 'percent': 200 (expected: >= 0 <= 100))") + } + + test("create failing ppl percentile approx - due to too low percentage value test") { + val thrown = intercept[IllegalStateException] { + val frame = sql(s""" + | source = $testTable | stats percentile_approx(age, -4) by country + | """.stripMargin) + } + assert(thrown.getMessage === "Unsupported value 'percent': -4 (expected: >= 0 <= 100))") + } } diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index fed7e0038..3ea0a477a 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -272,6 +272,8 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | stats count(c) by b | head 5` - `source = table | stats stddev_samp(c)` - `source = table | stats stddev_pop(c)` + - `source = table | stats percentile(c, 90)` + - `source = table | stats percentile_approx(c, 99)` **Aggregations With Span** - `source = table | stats count(a) by span(a, 10) as a_span` diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index b1c988b28..d202f5ff6 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -196,6 +196,7 @@ VAR_POP: 'VAR_POP'; STDDEV_SAMP: 'STDDEV_SAMP'; STDDEV_POP: 'STDDEV_POP'; PERCENTILE: 'PERCENTILE'; +PERCENTILE_APPROX: 'PERCENTILE_APPROX'; TAKE: 'TAKE'; FIRST: 'FIRST'; LAST: 'LAST'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 76e65753b..f4065be6d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -225,9 +225,10 @@ statsAggTerm // aggregation functions statsFunction - : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall - | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall - | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall + | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall + | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall + | percentileFunctionName = (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS valueExpression COMMA percent = integerLiteral RT_PRTHS # percentileFunctionCall ; statsFunctionName @@ -897,6 +898,7 @@ keywordsCanBeId | STDDEV_SAMP | STDDEV_POP | PERCENTILE + | PERCENTILE_APPROX | TAKE | FIRST | LAST diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f12648eb2..eb22164b9 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -174,6 +174,8 @@ public enum BuiltinFunctionName { TAKE(FunctionName.of("take")), // Not always an aggregation query NESTED(FunctionName.of("nested")), + PERCENTILE(FunctionName.of("percentile")), + PERCENTILE_APPROX(FunctionName.of("percentile_approx")), /** Text Functions. */ ASCII(FunctionName.of("ascii")), @@ -285,6 +287,8 @@ public FunctionName getName() { .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) .put("take", BuiltinFunctionName.TAKE) + .put("percentile", BuiltinFunctionName.PERCENTILE) + .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) .build(); public static Optional of(String str) { diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 71abb329f..352853398 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -172,6 +172,14 @@ public UnresolvedExpression visitPercentileAggFunction(OpenSearchPPLParser.Perce Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value)))); } + @Override + public UnresolvedExpression visitPercentileFunctionCall(OpenSearchPPLParser.PercentileFunctionCallContext ctx) { + return new AggregateFunction( + ctx.percentileFunctionName.getText(), + visit(ctx.valueExpression()), + Collections.singletonList(new Argument("percent", (Literal) visit(ctx.percent)))); + } + /** * Eval function. */ diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java index eba60248d..244f71f09 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/AggregatorTranslator.java @@ -7,8 +7,16 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction; import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.Literal; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Argument; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import java.util.List; + import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static scala.Option.empty; @@ -39,7 +47,33 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); case STDDEV_SAMP: return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false); + case PERCENTILE: + return new UnresolvedFunction(seq("PERCENTILE"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); + case PERCENTILE_APPROX: + return new UnresolvedFunction(seq("PERCENTILE_APPROX"), seq(arg, new Literal(getPercentDoubleValue(aggregateFunction), DataTypes.DoubleType)), aggregateFunction.getDistinct(), empty(),false); } throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName()); } + + private static double getPercentDoubleValue(AggregateFunction aggregateFunction) { + + List arguments = aggregateFunction.getArgList(); + + if (arguments == null || arguments.size() != 1) { + throw new IllegalStateException("Missing 'percent' argument"); + } + + org.opensearch.sql.ast.expression.Literal percentIntValue = ((Argument) aggregateFunction.getArgList().get(0)).getValue(); + + if (percentIntValue.getType() != DataType.INTEGER) { + throw new IllegalStateException("Unsupported datatype for 'percent': " + percentIntValue.getType() + " (expected: INTEGER)"); + } + + double percentDoubleValue = ((Integer) percentIntValue.getValue()) / 100d; + + if (percentDoubleValue < 0 || percentDoubleValue > 1) { + throw new IllegalStateException("Unsupported value 'percent': " + percentIntValue.getValue() + " (expected: >= 0 <= 100))"); + } + return percentDoubleValue; + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 61190294b..457faeaa3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -594,4 +594,168 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test price 50th percentile group by product sorted") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats percentile(price, 50) by product | sort product", + false), + context) + val star = Seq(UnresolvedStar(None)) + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val percentage = Literal(0.5) + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(priceField, percentage), isDistinct = false), + "percentile(price, 50)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test price 20th percentile with alias and filter") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table category = 'vegetable' | stats percentile(price, 20) as price_20_percentile", + false), + context) + val star = Seq(UnresolvedStar(None)) + val categoryField = UnresolvedAttribute("category") + val priceField = UnresolvedAttribute("price") + val percentage = Literal(0.2) + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(priceField, percentage), isDistinct = false), + "price_20_percentile")()) + val filterExpr = EqualTo(categoryField, Literal("vegetable")) + val filterPlan = Filter(filterExpr, tableRelation) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test age 40th percentile by span of interval of 5 years query with sort ") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats percentile(age, 40) by span(age, 5) as age_span | sort age", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val percentage = Literal(0.4) + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(ageField, percentage), isDistinct = false), + "percentile(age, 40)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(5))), Literal(5)), + "age_span")() + val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test sum number of flights by airport and calculate 30th percentile with aliases") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats sum(no_of_flights) as flights_count by airport | stats percentile(flights_count, 30) as percentile_30", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = UnresolvedAttribute("airport") + val percentage = Literal(0.3) + val flightsCountField = UnresolvedAttribute("flights_count") + val tableRelation = UnresolvedRelation(Seq("table")) + + val airportAlias = Alias(airportField, "airport")() + val sumAggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(numberOfFlightsField), isDistinct = false), + "flights_count")() + val sumGroupByAttributes = Seq(Alias(airportField, "airport")()) + val sumAggregatePlan = + Aggregate(sumGroupByAttributes, Seq(sumAggregateExpressions, airportAlias), tableRelation) + + val percentileAggregateExpressions = + Alias( + UnresolvedFunction( + Seq("PERCENTILE"), + Seq(flightsCountField, percentage), + isDistinct = false), + "percentile_30")() + val percentileAggregatePlan = + Aggregate(Seq(), Seq(percentileAggregateExpressions), sumAggregatePlan) + val expectedPlan = Project(star, percentileAggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test sum number of flights by airport and calculate 30th percentile approx with aliases") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table | stats sum(no_of_flights) as flights_count by airport | stats percentile_approx(flights_count, 30) as percentile_approx_30", + false), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val numberOfFlightsField = UnresolvedAttribute("no_of_flights") + val airportField = UnresolvedAttribute("airport") + val percentage = Literal(0.3) + val flightsCountField = UnresolvedAttribute("flights_count") + val tableRelation = UnresolvedRelation(Seq("table")) + + val airportAlias = Alias(airportField, "airport")() + val sumAggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(numberOfFlightsField), isDistinct = false), + "flights_count")() + val sumGroupByAttributes = Seq(Alias(airportField, "airport")()) + val sumAggregatePlan = + Aggregate(sumGroupByAttributes, Seq(sumAggregateExpressions, airportAlias), tableRelation) + + val percentileAggregateExpressions = + Alias( + UnresolvedFunction( + Seq("PERCENTILE_APPROX"), + Seq(flightsCountField, percentage), + isDistinct = false), + "percentile_approx_30")() + val percentileAggregatePlan = + Aggregate(Seq(), Seq(percentileAggregateExpressions), sumAggregatePlan) + val expectedPlan = Project(star, percentileAggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + }