Skip to content

Commit

Permalink
Implement stddev_samp and stddev_pop ppl stats function (opensearch-p…
Browse files Browse the repository at this point in the history
…roject#549)

* Implement stddev_samp and stddev_pop ppl stats function

Signed-off-by: Hendrik Saly <[email protected]>

* add tests for stats: stdev_samp, stdev_pop

Signed-off-by: Kacper Trochimiak <[email protected]>

* Fix scala style

Signed-off-by: Hendrik Saly <[email protected]>

* Add functions to readme

Signed-off-by: Hendrik Saly <[email protected]>

---------

Signed-off-by: Hendrik Saly <[email protected]>
Signed-off-by: Kacper Trochimiak <[email protected]>
Co-authored-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
salyh and kt-eliatra authored Aug 20, 2024
1 parent a91bf35 commit 6fc2010
Show file tree
Hide file tree
Showing 6 changed files with 555 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, Floor, Literal, Multiply, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, Divide, EqualTo, Floor, Literal, Multiply, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

Expand Down Expand Up @@ -262,4 +262,93 @@ class FlintSparkPPLAggregationWithSpanITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_stddev_samp |
* |:---------|-------------------:|
* | 20 | 3.5355339059327378 |
*/
test(
"create ppl age sample stddev by span of interval of 10 years query with country filter test ") {
val frame = sql(s"""
| source = $testTable | where country != 'USA' | stats stddev_samp(age) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(3.5355339059327378d, 20L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val countryField = UnresolvedAttribute("country")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(countryField, Literal("USA")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

/**
* | age_span | age_stddev_pop |
* |:---------|---------------:|
* | 20 | 2.5 |
* | 30 | 0 |
*/
test(
"create ppl age population stddev by span of interval of 10 years query with state filter test ") {
val frame = sql(s"""
| source = $testTable | where state != 'California' | stats stddev_pop(age) by span(age, 10) as age_span
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(2.5d, 20L), Row(0d, 30L))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Long](_.getAs[Long](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val stateField = UnresolvedAttribute("state")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val span = Alias(
Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)),
"age_span")()
val filterExpr = Not(EqualTo(stateField, Literal("California")))
val filterPlan = Filter(filterExpr, table)
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,239 @@ class FlintSparkPPLAggregationsITSuite
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl age sample stddev") {
val frame = sql(s"""
| source = $testTable| stats stddev_samp(age)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(22.86737122335374d))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)
// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age sample stddev group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats stddev_samp(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] =
Array(Row(3.5355339059327378d, "Canada"), Row(28.284271247461902d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age sample stddev group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats stddev_samp(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(null, "Canada"), Row(28.284271247461902d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_SAMP"), Seq(ageField), isDistinct = false),
"stddev_samp(age)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("create ppl age population stddev") {
val frame = sql(s"""
| source = $testTable| stats stddev_pop(age)
| """.stripMargin)
// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(19.803724397193573d))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")
// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val aggregatePlan =
Aggregate(Seq.empty, Seq(aggregateExpressions), table)
val expectedPlan = Project(star, aggregatePlan)
// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age population stddev group by country query test with sort") {
val frame = sql(s"""
| source = $testTable | stats stddev_pop(age) by country | sort country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(2.5d, "Canada"), Row(20d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(
results.sorted.sameElements(expectedResults.sorted),
s"Expected: ${expectedResults.mkString(", ")}, but got: ${results.mkString(", ")}")

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val productAlias = Alias(countryField, "country")()

val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), table)
val sortedPlan: LogicalPlan =
Sort(
Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

// Compare the two plans
assert(
compareByString(expectedPlan) === compareByString(logicalPlan),
s"Expected plan: ${compareByString(expectedPlan)}, but got: ${compareByString(logicalPlan)}")
}

test("create ppl age population stddev group by country with state filter query test") {
val frame = sql(s"""
| source = $testTable | where state != 'Ontario' | stats stddev_pop(age) by country
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(Row(0d, "Canada"), Row(20d, "USA"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](1))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val star = Seq(UnresolvedStar(None))
val stateField = UnresolvedAttribute("state")
val countryField = UnresolvedAttribute("country")
val ageField = UnresolvedAttribute("age")
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

val groupByAttributes = Seq(Alias(countryField, "country")())
val filterExpr = Not(EqualTo(stateField, Literal("Ontario")))
val filterPlan = Filter(filterExpr, table)
val aggregateExpressions =
Alias(
UnresolvedFunction(Seq("STDDEV_POP"), Seq(ageField), isDistinct = false),
"stddev_pop(age)")()
val productAlias = Alias(countryField, "country")()
val aggregatePlan =
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
2 changes: 2 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | where a < 50 | stats avg(c) `
- `source = table | stats max(c) by b`
- `source = table | stats count(c) by b | head 5`
- `source = table | stats stddev_samp(c)`
- `source = table | stats stddev_pop(c)`

**Aggregations With Span**
- `source = table | stats count(a) by span(a, 10) as a_span`
Expand Down
2 changes: 2 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ statsFunctionName
| SUM
| MIN
| MAX
| STDDEV_SAMP
| STDDEV_POP
;

takeAggFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ static Expression aggregator(org.opensearch.sql.ast.expression.AggregateFunction
return new UnresolvedFunction(seq("COUNT"), seq(arg),false, empty(),false);
case SUM:
return new UnresolvedFunction(seq("SUM"), seq(arg),false, empty(),false);
case STDDEV_POP:
return new UnresolvedFunction(seq("STDDEV_POP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
case STDDEV_SAMP:
return new UnresolvedFunction(seq("STDDEV_SAMP"), seq(arg), aggregateFunction.getDistinct(), empty(),false);
}
throw new IllegalStateException("Not Supported value: " + aggregateFunction.getFuncName());
}
Expand Down
Loading

0 comments on commit 6fc2010

Please sign in to comment.