Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 0.5-nexus] Top & Rare PPL commands support #584

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.streaming.StreamTest

/**
 * Integration tests for the PPL `top` and `rare` commands.
 *
 * Each test runs a PPL query against a small partitioned test table, checks the
 * returned rows, and then compares the logical plan produced by the PPL planner
 * against a hand-built expected Catalyst plan (COUNT aggregation + Sort, plus a
 * Limit for `top N`).
 */
class FlintSparkPPLTopAndRareITSuite
    extends QueryTest
    with LogicalPlanTestUtils
    with FlintPPLSuite
    with StreamTest {

  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedMultiRowAddressTable(testTable)
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Stop all streaming jobs if any
    spark.streams.active.foreach { job =>
      job.stop()
      job.awaitTermination()
    }
  }

  test("create ppl rare address field query test") {
    val frame = sql(s"""
         | source = $testTable| rare address
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 3)

    // `rare` surfaces the least frequent value first
    val expectedRow = Row(1, "Vancouver")
    assert(
      results.head == expectedRow,
      s"Expected least frequent result to be $expectedRow, but got ${results.head}")

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan: COUNT(address) grouped by address, sorted Descending
    val addressField = UnresolvedAttribute("address")
    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

    val aggregateExpressions = Seq(
      Alias(
        UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
        "count(address)")(),
      addressField)
    val aggregatePlan =
      Aggregate(
        Seq(addressField),
        aggregateExpressions,
        UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
        global = true,
        aggregatePlan)
    val expectedPlan = Project(projectList, sortedPlan)
    comparePlans(expectedPlan, logicalPlan, false)
  }

  test("create ppl rare address by age field query test") {
    val frame = sql(s"""
         | source = $testTable| rare address by age
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 5)

    val expectedRow = Row(1, "Vancouver", 60)
    assert(
      results.head == expectedRow,
      s"Expected least frequent result to be $expectedRow, but got ${results.head}")

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    val addressField = UnresolvedAttribute("address")
    val ageField = UnresolvedAttribute("age")
    // The `by` field appears in the grouping list as an alias
    val ageAlias = Alias(ageField, "age")()

    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

    val countExpr = Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
      "count(address)")()

    val aggregateExpressions = Seq(countExpr, addressField, ageAlias)
    val aggregatePlan =
      Aggregate(
        Seq(addressField, ageAlias),
        aggregateExpressions,
        UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))

    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
        global = true,
        aggregatePlan)

    val expectedPlan = Project(projectList, sortedPlan)
    comparePlans(expectedPlan, logicalPlan, false)
  }

  test("create ppl top address field query test") {
    val frame = sql(s"""
         | source = $testTable| top address
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 3)

    // Two values tie for most frequent; compare as a set to avoid ordering flakiness
    val expectedRows = Set(Row(2, "Portland"), Row(2, "Seattle"))
    val actualRows = results.take(2).toSet

    // Compare the sets
    assert(
      actualRows == expectedRows,
      s"The first two results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical
    // Define the expected logical plan: COUNT(address) grouped by address, sorted Ascending
    val addressField = UnresolvedAttribute("address")
    val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))

    val aggregateExpressions = Seq(
      Alias(
        UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
        "count(address)")(),
      addressField)
    val aggregatePlan =
      Aggregate(
        Seq(addressField),
        aggregateExpressions,
        UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("address"), Ascending)),
        global = true,
        aggregatePlan)
    val expectedPlan = Project(projectList, sortedPlan)
    comparePlans(expectedPlan, logicalPlan, false)
  }

  test("create ppl top 3 countries by occupation field query test") {
    val newTestTable = "spark_catalog.default.new_flint_ppl_test"
    createOccupationTable(newTestTable)

    val frame = sql(s"""
         | source = $newTestTable| top 3 country by occupation
         | """.stripMargin)

    // Retrieve the results
    val results: Array[Row] = frame.collect()
    assert(results.length == 3)

    val expectedRows = Set(
      Row(1, "Canada", "Doctor"),
      Row(1, "Canada", "Scientist"),
      Row(1, "Canada", "Unemployed"))
    val actualRows = results.take(3).toSet

    // Compare the sets
    // FIX: the message previously said "first two" although three rows are compared
    assert(
      actualRows == expectedRows,
      s"The first three results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")

    // Retrieve the logical plan
    val logicalPlan: LogicalPlan = frame.queryExecution.logical

    val countryField = UnresolvedAttribute("country")
    val occupationField = UnresolvedAttribute("occupation")
    val occupationFieldAlias = Alias(occupationField, "occupation")()

    val countExpr = Alias(
      UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
      "count(country)")()
    val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
    val aggregatePlan =
      Aggregate(
        Seq(countryField, occupationFieldAlias),
        aggregateExpressions,
        UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))

    val sortedPlan: LogicalPlan =
      Sort(
        Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
        global = true,
        aggregatePlan)

    // `top 3` adds a GlobalLimit/LocalLimit pair on top of the sorted aggregation
    val planWithLimit =
      GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
    comparePlans(expectedPlan, logicalPlan, false)
  }
}
12 changes: 10 additions & 2 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
@@ -278,7 +278,6 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | stats sum(productsAmount) by span(transactionDate, 1w) as age_date, productId`

**Dedup**

- `source = table | dedup a | fields a,b,c`
- `source = table | dedup a,b | fields a,b,c`
- `source = table | dedup a keepempty=true | fields a,b,c`
@@ -290,8 +289,17 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported)
- `source = table | dedup 2 a | fields a,b,c` (Unsupported)

**Rare**
- `source=accounts | rare gender`
- `source=accounts | rare age by gender`

**Top**
- `source=accounts | top gender`
- `source=accounts | top 1 gender`
- `source=accounts | top 1 age by gender`


For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)
> For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)

---

Original file line number Diff line number Diff line change
@@ -38,6 +38,8 @@ commands
| dedupCommand
| sortCommand
| headCommand
| topCommand
| rareCommand
| evalCommand
;

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;

/**
 * Logical plan node for the Rare (aggregation) command — the interface used to
 * build aggregation actions in queries. A Rare aggregation is an ordinary
 * {@link Aggregation} with no span and no arguments.
 */
public class RareAggregation extends Aggregation {

  /**
   * Creates a Rare aggregation (no span, empty argument list).
   *
   * @param aggExprList   aggregation expressions (e.g. COUNT over the rare field)
   * @param sortExprList  sort expressions driving result ordering
   * @param groupExprList group-by expressions (the rare field and any `by` fields)
   */
  public RareAggregation(
      List<UnresolvedExpression> aggExprList,
      List<UnresolvedExpression> sortExprList,
      List<UnresolvedExpression> groupExprList) {
    super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * Logical plan node for the Top (aggregation) command — the interface used to
 * build aggregation actions in queries. A Top aggregation is an ordinary
 * {@link Aggregation} with no span and no arguments, optionally limited to the
 * first N result rows (the N in {@code top N field}).
 */
public class TopAggregation extends Aggregation {

  // Optional row limit; empty when the query used plain `top field` without N.
  private final Optional<Literal> results;

  /**
   * Creates a Top aggregation (no span, empty argument list).
   *
   * @param results       optional limit on the number of returned rows
   * @param aggExprList   aggregation expressions (e.g. COUNT over the top field)
   * @param sortExprList  sort expressions driving result ordering
   * @param groupExprList group-by expressions (the top field and any `by` fields)
   */
  public TopAggregation(
      Optional<Literal> results,
      List<UnresolvedExpression> aggExprList,
      List<UnresolvedExpression> sortExprList,
      List<UnresolvedExpression> groupExprList) {
    super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
    this.results = results;
  }

  /** Returns the optional row limit supplied as `top N`. */
  public Optional<Literal> getResults() {
    return results;
  }
}
Original file line number Diff line number Diff line change
@@ -9,9 +9,12 @@
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.Ascending$;
import org.apache.spark.sql.catalyst.expressions.Descending$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
@@ -59,9 +62,11 @@
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.RareAggregation;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.TopAggregation;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
@@ -176,20 +181,39 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
node.getChild().get(0).accept(this, context);
List<Expression> aggsExpList = visitExpressionList(node.getAggExprList(), context);
List<Expression> groupExpList = visitExpressionList(node.getGroupExprList(), context);

if (!groupExpList.isEmpty()) {
//add group by fields to context
context.getGroupingParseExpressions().addAll(groupExpList);
}

UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
span.accept(this, context);
//add span's group alias field (most recent added expression)
context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek());
}
// build the aggregation logical step
return extractedAggregation(context);
LogicalPlan logicalPlan = extractedAggregation(context);

// set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc)
List<SortDirection> sortDirections = new ArrayList<>();
sortDirections.add(node instanceof RareAggregation ? Descending$.MODULE$ : Ascending$.MODULE$);

if (!node.getSortExprList().isEmpty()) {
visitExpressionList(node.getSortExprList(), context);
Seq<SortOrder> sortElements = context.retainAllNamedParseExpressions(exp ->
new SortOrder(exp,
sortDirections.get(0),
sortDirections.get(0).defaultNullOrdering(),
seq(new ArrayList<Expression>())));
context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, logicalPlan));
}
//visit TopAggregation results limit
if((node instanceof TopAggregation) && ((TopAggregation) node).getResults().isPresent()) {
context.apply(p ->(LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal(
((TopAggregation) node).getResults().get().getValue(), org.apache.spark.sql.types.DataTypes.IntegerType), p));
}
return logicalPlan;
}

private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
Loading